Repository: TransformerOptimus/SuperAGI
Branch: main
Commit: c3c1982e7bd6
Files: 630
Total size: 1.9 MB
Directory structure:
gitextract_6lrw3u4j/
├── .do/
│ ├── app.yaml
│ └── deploy.template.yaml
├── .dockerignore
├── .gitattributes
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ └── 1.BUG_REPORT.yml
│ ├── PULL_REQUEST_TEMPLATE.md
│ └── workflows/
│ ├── ci.yml
│ └── codeql.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── Dockerfile-gpu
├── DockerfileCelery
├── DockerfileRedis
├── LICENSE
├── README.MD
├── alembic.ini
├── cli2.py
├── config_template.yaml
├── docker-compose-dev.yaml
├── docker-compose-gpu.yml
├── docker-compose.image.example.yaml
├── docker-compose.yaml
├── entrypoint.sh
├── entrypoint_celery.sh
├── gui/
│ ├── .dockerignore
│ ├── .eslintrc.json
│ ├── Dockerfile
│ ├── DockerfileProd
│ ├── README.md
│ ├── app/
│ │ ├── globals.css
│ │ └── layout.js
│ ├── jsconfig.json
│ ├── next.config.js
│ ├── package.json
│ ├── pages/
│ │ ├── Content/
│ │ │ ├── APM/
│ │ │ │ ├── Apm.module.css
│ │ │ │ ├── ApmDashboard.js
│ │ │ │ └── BarGraph.js
│ │ │ ├── Agents/
│ │ │ │ ├── ActionConsole.js
│ │ │ │ ├── ActivityFeed.js
│ │ │ │ ├── AgentCreate.js
│ │ │ │ ├── AgentSchedule.js
│ │ │ │ ├── AgentTemplatesList.js
│ │ │ │ ├── AgentWorkspace.js
│ │ │ │ ├── Agents.js
│ │ │ │ ├── Agents.module.css
│ │ │ │ ├── Details.js
│ │ │ │ ├── ResourceList.js
│ │ │ │ ├── ResourceManager.js
│ │ │ │ ├── RunHistory.js
│ │ │ │ ├── TaskQueue.js
│ │ │ │ └── react-datetime.css
│ │ │ ├── Knowledge/
│ │ │ │ ├── AddKnowledge.js
│ │ │ │ ├── Knowledge.js
│ │ │ │ ├── Knowledge.module.css
│ │ │ │ ├── KnowledgeDetails.js
│ │ │ │ └── KnowledgeForm.js
│ │ │ ├── Marketplace/
│ │ │ │ ├── AgentTemplate.js
│ │ │ │ ├── KnowledgeTemplate.js
│ │ │ │ ├── Market.js
│ │ │ │ ├── Market.module.css
│ │ │ │ ├── MarketAgent.js
│ │ │ │ ├── MarketKnowledge.js
│ │ │ │ ├── MarketTools.js
│ │ │ │ ├── MarketplacePublic.js
│ │ │ │ └── ToolkitTemplate.js
│ │ │ ├── Models/
│ │ │ │ ├── AddModel.js
│ │ │ │ ├── AddModelMarketPlace.js
│ │ │ │ ├── MarketModels.js
│ │ │ │ ├── ModelDetails.js
│ │ │ │ ├── ModelForm.js
│ │ │ │ ├── ModelInfo.js
│ │ │ │ ├── ModelMetrics.js
│ │ │ │ ├── ModelTemplate.js
│ │ │ │ └── Models.js
│ │ │ └── Toolkits/
│ │ │ ├── AddTool.js
│ │ │ ├── Metrics.js
│ │ │ ├── Tool.module.css
│ │ │ ├── ToolkitWorkspace.js
│ │ │ └── Toolkits.js
│ │ ├── Dashboard/
│ │ │ ├── Content.js
│ │ │ ├── Dashboard.module.css
│ │ │ ├── Settings/
│ │ │ │ ├── AddDatabase.js
│ │ │ │ ├── ApiKeys.js
│ │ │ │ ├── Database.js
│ │ │ │ ├── DatabaseDetails.js
│ │ │ │ ├── Model.js
│ │ │ │ ├── Settings.js
│ │ │ │ └── Webhooks.js
│ │ │ ├── SideBar.js
│ │ │ └── TopBar.js
│ │ ├── _app.css
│ │ ├── _app.js
│ │ └── api/
│ │ ├── DashboardService.js
│ │ └── apiConfig.js
│ └── utils/
│ ├── eventBus.js
│ └── utils.js
├── install_tool_dependencies.sh
├── local-llm
├── local-llm-gpu
├── main.py
├── migrations/
│ ├── README
│ ├── env.py
│ ├── script.py.mako
│ └── versions/
│ ├── 1d54db311055_add_permissions.py
│ ├── 2cc1179834b0_agent_executions_modified.py
│ ├── 2f97c068fab9_resource_modified.py
│ ├── 2fbd6472112c_add_feed_group_id_to_execution_and_feed.py
│ ├── 3356a2f89a33_added_configurations_table.py
│ ├── 35e47f20475b_renamed_tokens_calls.py
│ ├── 3867bb00a495_added_first_login_source.py
│ ├── 40affbf3022b_add_filter_colume_in_webhooks.py
│ ├── 446884dcae58_add_api_key_and_web_hook.py
│ ├── 44b0d6f2d1b3_init_models.py
│ ├── 467e85d5e1cd_updated_resources_added_exec_id.py
│ ├── 516ecc1c723d_adding_marketplace_template_id_to_agent_.py
│ ├── 5184645e9f12_add_question_to_agent_execution_.py
│ ├── 520aa6776347_create_models_config.py
│ ├── 598cfb37292a_adding_agent_templates.py
│ ├── 5d5f801f28e7_create_model_table.py
│ ├── 661ec8a4c32e_open_ai_error_handling.py
│ ├── 71e3980d55f5_knowledge_and_vector_dbs.py
│ ├── 7a3e336c0fba_added_tools_related_models.py
│ ├── 83424de1347e_added_agent_execution_config.py
│ ├── 8962bed0d809_creating_agent_templates.py
│ ├── 9270eb5a8475_local_llms.py
│ ├── 9419b3340af7_create_agent_workflow.py
│ ├── a91808a89623_added_resources.py
│ ├── ba60b12ae109_create_agent_scheduler.py
│ ├── be1d922bf2ad_create_call_logs_table.py
│ ├── c02f3d759bf3_add_summary_to_resource.py
│ ├── c4f2f6ba602a_agent_workflow_wait_step.py
│ ├── c5c19944c90c_create_oauth_tokens.py
│ ├── cac478732572_delete_agent_feature.py
│ ├── d8315244ea43_updated_tool_configs.py
│ ├── d9b3436197eb_renaming_templates.py
│ ├── e39295ec089c_creating_events.py
│ └── fe234ea6e9bc_modify_agent_workflow_tables.py
├── nginx/
│ └── default.conf
├── package.json
├── requirements.txt
├── run.bat
├── run.sh
├── run_gui.py
├── run_gui.sh
├── superagi/
│ ├── __init__.py
│ ├── agent/
│ │ ├── __init__.py
│ │ ├── agent_iteration_step_handler.py
│ │ ├── agent_message_builder.py
│ │ ├── agent_prompt_builder.py
│ │ ├── agent_prompt_template.py
│ │ ├── agent_tool_step_handler.py
│ │ ├── agent_workflow_step_wait_handler.py
│ │ ├── common_types.py
│ │ ├── output_handler.py
│ │ ├── output_parser.py
│ │ ├── prompts/
│ │ │ ├── agent_queue_input.txt
│ │ │ ├── agent_recursive_summary.txt
│ │ │ ├── agent_summary.txt
│ │ │ ├── agent_tool_input.txt
│ │ │ ├── agent_tool_output.txt
│ │ │ ├── analyse_task.txt
│ │ │ ├── create_tasks.txt
│ │ │ ├── initialize_tasks.txt
│ │ │ ├── prioritize_tasks.txt
│ │ │ └── superagi.txt
│ │ ├── queue_step_handler.py
│ │ ├── task_queue.py
│ │ ├── tool_builder.py
│ │ ├── tool_executor.py
│ │ ├── types/
│ │ │ ├── __init__.py
│ │ │ ├── agent_execution_status.py
│ │ │ ├── agent_workflow_step_action_types.py
│ │ │ └── wait_step_status.py
│ │ └── workflow_seed.py
│ ├── apm/
│ │ ├── __init__.py
│ │ ├── analytics_helper.py
│ │ ├── call_log_helper.py
│ │ ├── event_handler.py
│ │ ├── knowledge_handler.py
│ │ └── tools_handler.py
│ ├── config/
│ │ ├── __init__.py
│ │ └── config.py
│ ├── controllers/
│ │ ├── __init__.py
│ │ ├── agent.py
│ │ ├── agent_execution.py
│ │ ├── agent_execution_config.py
│ │ ├── agent_execution_feed.py
│ │ ├── agent_execution_permission.py
│ │ ├── agent_template.py
│ │ ├── agent_workflow.py
│ │ ├── analytics.py
│ │ ├── api/
│ │ │ └── agent.py
│ │ ├── api_key.py
│ │ ├── budget.py
│ │ ├── config.py
│ │ ├── google_oauth.py
│ │ ├── knowledge_configs.py
│ │ ├── knowledges.py
│ │ ├── marketplace_stats.py
│ │ ├── models_controller.py
│ │ ├── organisation.py
│ │ ├── project.py
│ │ ├── resources.py
│ │ ├── tool.py
│ │ ├── tool_config.py
│ │ ├── toolkit.py
│ │ ├── twitter_oauth.py
│ │ ├── types/
│ │ │ ├── agent_execution_config.py
│ │ │ ├── agent_publish_config.py
│ │ │ ├── agent_schedule.py
│ │ │ ├── agent_with_config.py
│ │ │ ├── agent_with_config_schedule.py
│ │ │ └── models_types.py
│ │ ├── user.py
│ │ ├── vector_db_indices.py
│ │ ├── vector_dbs.py
│ │ └── webhook.py
│ ├── helper/
│ │ ├── agent_schedule_helper.py
│ │ ├── auth.py
│ │ ├── calendar_date.py
│ │ ├── encyption_helper.py
│ │ ├── error_handler.py
│ │ ├── feed_parser.py
│ │ ├── github_helper.py
│ │ ├── google_calendar_creds.py
│ │ ├── google_search.py
│ │ ├── google_serp.py
│ │ ├── imap_email.py
│ │ ├── json_cleaner.py
│ │ ├── llm_loader.py
│ │ ├── models_helper.py
│ │ ├── prompt_reader.py
│ │ ├── read_email.py
│ │ ├── resource_helper.py
│ │ ├── s3_helper.py
│ │ ├── time_helper.py
│ │ ├── token_counter.py
│ │ ├── tool_helper.py
│ │ ├── twitter_helper.py
│ │ ├── twitter_tokens.py
│ │ ├── validate_csv.py
│ │ ├── webhook_manager.py
│ │ └── webpage_extractor.py
│ ├── image_llms/
│ │ ├── __init__.py
│ │ ├── base_image_llm.py
│ │ └── openai_dalle.py
│ ├── jobs/
│ │ ├── __init__.py
│ │ ├── agent_executor.py
│ │ └── scheduling_executor.py
│ ├── lib/
│ │ └── logger.py
│ ├── llms/
│ │ ├── __init__.py
│ │ ├── base_llm.py
│ │ ├── google_palm.py
│ │ ├── grammar/
│ │ │ └── json.gbnf
│ │ ├── hugging_face.py
│ │ ├── llm_model_factory.py
│ │ ├── local_llm.py
│ │ ├── openai.py
│ │ ├── replicate.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ └── huggingface_utils/
│ │ ├── __init__.py
│ │ ├── public_endpoints.py
│ │ └── tasks.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── agent.py
│ │ ├── agent_config.py
│ │ ├── agent_execution.py
│ │ ├── agent_execution_config.py
│ │ ├── agent_execution_feed.py
│ │ ├── agent_execution_permission.py
│ │ ├── agent_schedule.py
│ │ ├── agent_template.py
│ │ ├── agent_template_config.py
│ │ ├── api_key.py
│ │ ├── base_model.py
│ │ ├── budget.py
│ │ ├── call_logs.py
│ │ ├── configuration.py
│ │ ├── db.py
│ │ ├── events.py
│ │ ├── knowledge_configs.py
│ │ ├── knowledges.py
│ │ ├── marketplace_stats.py
│ │ ├── models.py
│ │ ├── models_config.py
│ │ ├── oauth_tokens.py
│ │ ├── organisation.py
│ │ ├── project.py
│ │ ├── resource.py
│ │ ├── tool.py
│ │ ├── tool_config.py
│ │ ├── toolkit.py
│ │ ├── types/
│ │ │ ├── __init__.py
│ │ │ ├── agent_config.py
│ │ │ ├── login_request.py
│ │ │ └── validate_llm_api_key_request.py
│ │ ├── user.py
│ │ ├── vector_db_configs.py
│ │ ├── vector_db_indices.py
│ │ ├── vector_dbs.py
│ │ ├── webhook_events.py
│ │ ├── webhooks.py
│ │ └── workflows/
│ │ ├── __init__.py
│ │ ├── agent_workflow.py
│ │ ├── agent_workflow_step.py
│ │ ├── agent_workflow_step_tool.py
│ │ ├── agent_workflow_step_wait.py
│ │ ├── iteration_workflow.py
│ │ └── iteration_workflow_step.py
│ ├── resource_manager/
│ │ ├── __init__.py
│ │ ├── file_manager.py
│ │ ├── llama_document_summary.py
│ │ ├── llama_vector_store_factory.py
│ │ ├── resource_manager.py
│ │ └── resource_summary.py
│ ├── tool_manager.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── apollo/
│ │ │ ├── __init__.py
│ │ │ ├── apollo_search.py
│ │ │ └── apollo_toolkit.py
│ │ ├── base_tool.py
│ │ ├── code/
│ │ │ ├── README.MD
│ │ │ ├── __init__.py
│ │ │ ├── coding_toolkit.py
│ │ │ ├── improve_code.py
│ │ │ ├── prompts/
│ │ │ │ ├── generate_logic.txt
│ │ │ │ ├── improve_code.txt
│ │ │ │ ├── write_code.txt
│ │ │ │ ├── write_spec.txt
│ │ │ │ └── write_test.txt
│ │ │ ├── write_code.py
│ │ │ ├── write_spec.py
│ │ │ └── write_test.py
│ │ ├── duck_duck_go/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── duck_duck_go_search.py
│ │ │ └── duck_duck_go_search_toolkit.py
│ │ ├── email/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── email_toolkit.py
│ │ │ ├── read_email.py
│ │ │ ├── send_email.py
│ │ │ └── send_email_attachment.py
│ │ ├── file/
│ │ │ ├── __init__.py
│ │ │ ├── append_file.py
│ │ │ ├── delete_file.py
│ │ │ ├── file_toolkit.py
│ │ │ ├── list_files.py
│ │ │ ├── read_file.py
│ │ │ └── write_file.py
│ │ ├── github/
│ │ │ ├── README.MD
│ │ │ ├── __init__.py
│ │ │ ├── add_file.py
│ │ │ ├── delete_file.py
│ │ │ ├── fetch_pull_request.py
│ │ │ ├── github_toolkit.py
│ │ │ ├── prompts/
│ │ │ │ └── code_review.txt
│ │ │ ├── review_pull_request.py
│ │ │ └── search_repo.py
│ │ ├── google_calendar/
│ │ │ ├── README.md
│ │ │ ├── create_calendar_event.py
│ │ │ ├── delete_calendar_event.py
│ │ │ ├── event_details_calendar.py
│ │ │ ├── google_calendar_toolkit.py
│ │ │ └── list_calendar_events.py
│ │ ├── google_search/
│ │ │ ├── README.MD
│ │ │ ├── __init__.py
│ │ │ ├── google_search.py
│ │ │ └── google_search_toolkit.py
│ │ ├── google_serp_search/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── google_serp_search.py
│ │ │ └── google_serp_search_toolkit.py
│ │ ├── image_generation/
│ │ │ ├── README.MD
│ │ │ ├── README.STABLE_DIFFUSION.md
│ │ │ ├── __init__.py
│ │ │ ├── dalle_image_gen.py
│ │ │ ├── image_generation_toolkit.py
│ │ │ └── stable_diffusion_image_gen.py
│ │ ├── instagram_tool/
│ │ │ ├── README.MD
│ │ │ ├── __init__.py
│ │ │ ├── instagram.py
│ │ │ └── instagram_toolkit.py
│ │ ├── jira/
│ │ │ ├── README.MD
│ │ │ ├── __init__.py
│ │ │ ├── create_issue.py
│ │ │ ├── edit_issue.py
│ │ │ ├── get_projects.py
│ │ │ ├── jira_toolkit.py
│ │ │ ├── search_issues.py
│ │ │ └── tool.py
│ │ ├── knowledge_search/
│ │ │ ├── knowledge_search.py
│ │ │ └── knowledge_search_toolkit.py
│ │ ├── resource/
│ │ │ ├── __init__.py
│ │ │ ├── query_resource.py
│ │ │ └── resource_toolkit.py
│ │ ├── searx/
│ │ │ ├── README.MD
│ │ │ ├── __init__.py
│ │ │ ├── search_scraper.py
│ │ │ ├── searx.py
│ │ │ └── searx_toolkit.py
│ │ ├── slack/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── send_message.py
│ │ │ └── slack_toolkit.py
│ │ ├── thinking/
│ │ │ ├── __init__.py
│ │ │ ├── prompts/
│ │ │ │ └── thinking.txt
│ │ │ ├── thinking_toolkit.py
│ │ │ └── tools.py
│ │ ├── tool_response_query_manager.py
│ │ ├── twitter/
│ │ │ ├── README.md
│ │ │ ├── send_tweets.py
│ │ │ └── twitter_toolkit.py
│ │ └── webscaper/
│ │ ├── README.MD
│ │ ├── __init__.py
│ │ ├── tools.py
│ │ └── web_scraper_toolkit.py
│ ├── types/
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── key_type.py
│ │ ├── model_source_types.py
│ │ ├── queue_status.py
│ │ ├── storage_types.py
│ │ └── vector_store_types.py
│ ├── vector_embeddings/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── pinecone.py
│ │ ├── qdrant.py
│ │ ├── vector_embedding_factory.py
│ │ └── weaviate.py
│ ├── vector_store/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── chromadb.py
│ │ ├── document.py
│ │ ├── embedding/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── openai.py
│ │ │ └── palm.py
│ │ ├── pinecone.py
│ │ ├── qdrant.py
│ │ ├── redis.py
│ │ ├── vector_factory.py
│ │ └── weaviate.py
│ └── worker.py
├── test.py
├── test_main.http
├── tests/
│ ├── __init__.py
│ ├── integration_tests/
│ │ ├── __init__.py
│ │ ├── vector_embeddings/
│ │ │ ├── __init__.py
│ │ │ ├── test_pinecone.py
│ │ │ ├── test_qdrant.py
│ │ │ └── test_weaviate.py
│ │ └── vector_store/
│ │ ├── __init__.py
│ │ ├── test_qdrant.py
│ │ └── test_weaviate.py
│ ├── tools/
│ │ └── google_calendar/
│ │ ├── create_event_test.py
│ │ ├── delete_event_test.py
│ │ ├── event_details_test.py
│ │ └── list_events_test.py
│ └── unit_tests/
│ ├── __init__.py
│ ├── agent/
│ │ ├── __init__.py
│ │ ├── test_agent_iteration_step_handler.py
│ │ ├── test_agent_message_builder.py
│ │ ├── test_agent_prompt_builder.py
│ │ ├── test_agent_prompt_template.py
│ │ ├── test_agent_tool_step_handler.py
│ │ ├── test_agent_workflow_step_wait_handler.py
│ │ ├── test_output_handler.py
│ │ ├── test_output_parser.py
│ │ ├── test_queue_step_handler.py
│ │ ├── test_task_queue.py
│ │ ├── test_tool_builder.py
│ │ └── test_tool_executor.py
│ ├── apm/
│ │ ├── __init__.py
│ │ ├── test_analytics_helper.py
│ │ ├── test_call_log_helper.py
│ │ ├── test_event_handler.py
│ │ ├── test_knowledge_handler.py
│ │ └── test_tools_handler.py
│ ├── controllers/
│ │ ├── __init__.py
│ │ ├── api/
│ │ │ ├── __init__.py
│ │ │ └── test_agent.py
│ │ ├── test_agent.py
│ │ ├── test_agent_execution.py
│ │ ├── test_agent_execution_config.py
│ │ ├── test_agent_execution_feeds.py
│ │ ├── test_agent_template.py
│ │ ├── test_analytics.py
│ │ ├── test_models_controller.py
│ │ ├── test_publish_agent.py
│ │ ├── test_tool.py
│ │ ├── test_tool_config.py
│ │ ├── test_toolkit.py
│ │ ├── test_update_agent_config_table.py
│ │ └── test_user.py
│ ├── helper/
│ │ ├── __init__.py
│ │ ├── test_agent_schedule_helper.py
│ │ ├── test_calendar_date.py
│ │ ├── test_error_handling.py
│ │ ├── test_feed_parser.py
│ │ ├── test_github_helper.py
│ │ ├── test_json_cleaner.py
│ │ ├── test_resource_helper.py
│ │ ├── test_s3_helper.py
│ │ ├── test_time_helper.py
│ │ ├── test_token_counter.py
│ │ ├── test_tool_helper.py
│ │ ├── test_twitter_helper.py
│ │ ├── test_twitter_tokens.py
│ │ └── test_webhooks.py
│ ├── jobs/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_resource_summary.py
│ │ └── test_scheduling_executor.py
│ ├── llms/
│ │ ├── __init__.py
│ │ ├── test_google_palm.py
│ │ ├── test_hugging_face.py
│ │ ├── test_model_factory.py
│ │ ├── test_open_ai.py
│ │ └── test_replicate.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── test_agent.py
│ │ ├── test_agent_execution.py
│ │ ├── test_agent_execution_config.py
│ │ ├── test_agent_execution_feed.py
│ │ ├── test_agent_schedule.py
│ │ ├── test_agent_template.py
│ │ ├── test_agent_workflow.py
│ │ ├── test_agent_workflow_step.py
│ │ ├── test_agent_workflow_step_tool.py
│ │ ├── test_api_key.py
│ │ ├── test_call_logs.py
│ │ ├── test_configuration.py
│ │ ├── test_events.py
│ │ ├── test_iteration_workflow.py
│ │ ├── test_iteration_workflow_step.py
│ │ ├── test_knowledge_configs.py
│ │ ├── test_marketplace_stats.py
│ │ ├── test_models.py
│ │ ├── test_models_config.py
│ │ ├── test_project.py
│ │ ├── test_tool.py
│ │ ├── test_tool_config.py
│ │ ├── test_toolkit.py
│ │ ├── test_vector_db_configs.py
│ │ ├── test_vector_db_indices.py
│ │ └── test_vector_dbs.py
│ ├── resource_manager/
│ │ ├── __init__.py
│ │ ├── test_file_manager.py
│ │ ├── test_llama_document_creation.py
│ │ ├── test_llama_vector_store_factory.py
│ │ └── test_save_document_to_vector_store.py
│ ├── test_migrations_multiheads.py
│ ├── test_tool_manager.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── code/
│ │ │ ├── __init__.py
│ │ │ ├── test_improve_code.py
│ │ │ ├── test_write_code.py
│ │ │ ├── test_write_spec.py
│ │ │ └── test_write_test.py
│ │ ├── duck_duck_go/
│ │ │ ├── __init__.py
│ │ │ ├── test_duckduckgo_results.py
│ │ │ └── test_duckduckgo_toolkit.py
│ │ ├── email/
│ │ │ ├── __init__.py
│ │ │ ├── test_read_email.py
│ │ │ ├── test_send_email.py
│ │ │ └── test_send_email_attachment.py
│ │ ├── file/
│ │ │ ├── __init__.py
│ │ │ ├── test_list_files.py
│ │ │ └── test_read_file.py
│ │ ├── github/
│ │ │ ├── __init__.py
│ │ │ ├── test_add_file.py
│ │ │ ├── test_fetch_pull_request.py
│ │ │ ├── test_github_delete.py
│ │ │ └── test_review_pull_request.py
│ │ ├── image_generation/
│ │ │ ├── __init__.py
│ │ │ ├── test_dalle_image_gen.py
│ │ │ └── test_stable_diffusion_image_gen.py
│ │ ├── instagram_tool/
│ │ │ ├── __init__.py
│ │ │ ├── test_instagram_tool.py
│ │ │ └── test_instagram_toolkit.py
│ │ ├── jira/
│ │ │ ├── __init__.py
│ │ │ ├── test_create_issue.py
│ │ │ ├── test_edit_issue.py
│ │ │ ├── test_get_projects.py
│ │ │ └── test_search_issues.py
│ │ ├── knowledge_tool/
│ │ │ ├── __init__.py
│ │ │ └── test_knowledge_search.py
│ │ ├── searx/
│ │ │ ├── __init__.py
│ │ │ └── test_searx_toolkit.py
│ │ ├── test_search_repo.py
│ │ └── twitter/
│ │ └── test_send_tweets.py
│ ├── types/
│ │ ├── __init__.py
│ │ └── test_model_source_types.py
│ ├── vector_embeddings/
│ │ ├── __init__.py
│ │ └── test_vector_embedding_factory.py
│ └── vector_store/
│ ├── __init__.py
│ ├── test_chromadb.py
│ ├── test_redis.py
│ └── test_vector_factory.py
├── tgwui/
│ ├── DockerfileTGWUI
│ ├── config/
│ │ ├── loras/
│ │ │ └── place-your-loras-here.txt
│ │ ├── presets/
│ │ │ ├── Debug-deterministic.yaml
│ │ │ ├── Kobold-Godlike.yaml
│ │ │ ├── Kobold-Liminal Drift.yaml
│ │ │ ├── LLaMA-Precise.yaml
│ │ │ ├── Naive.yaml
│ │ │ ├── NovelAI-Best Guess.yaml
│ │ │ ├── NovelAI-Decadence.yaml
│ │ │ ├── NovelAI-Genesis.yaml
│ │ │ ├── NovelAI-Lycaenidae.yaml
│ │ │ ├── NovelAI-Ouroboros.yaml
│ │ │ ├── NovelAI-Pleasing Results.yaml
│ │ │ ├── NovelAI-Sphinx Moth.yaml
│ │ │ ├── NovelAI-Storywriter.yaml
│ │ │ ├── Special-Contrastive Search.yaml
│ │ │ └── Special-Eta Sampling.yaml
│ │ ├── prompts/
│ │ │ ├── Alpaca-with-Input.txt
│ │ │ ├── GPT-4chan.txt
│ │ │ └── QA.txt
│ │ └── training/
│ │ ├── datasets/
│ │ │ └── put-trainer-datasets-here.txt
│ │ └── formats/
│ │ ├── alpaca-chatbot-format.json
│ │ └── alpaca-format.json
│ └── scripts/
│ ├── build_extensions.sh
│ └── docker-entrypoint.sh
├── ui.py
└── wait-for-it.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .do/app.yaml
================================================
alerts:
- rule: DEPLOYMENT_FAILED
- rule: DOMAIN_FAILED
databases:
- engine: PG
name: super-agi-main
num_nodes: 1
size: basic-xs
version: "12"
ingress:
rules:
- component:
name: superagi-backend
match:
path:
prefix: /api
name: superagi
services:
- dockerfile_path: DockerfileRedis
github:
branch: main
deploy_on_push: true
repo: TransformerOptimus/SuperAGI
internal_ports:
- 6379
instance_count: 1
instance_size_slug: basic-xs
source_dir: /
name: superagi-redis
- dockerfile_path: Dockerfile
envs:
- key: REDIS_URL
scope: RUN_TIME
value: superagi-redis:6379
- key: DB_URL
scope: RUN_TIME
value: ${super-agi-main.DATABASE_URL}
github:
branch: main
deploy_on_push: true
repo: TransformerOptimus/SuperAGI
http_port: 8001
instance_count: 1
instance_size_slug: basic-xs
run_command: /app/entrypoint.sh
source_dir: /
name: superagi-backend
- dockerfile_path: ./gui/DockerfileProd
github:
branch: main
deploy_on_push: true
repo: TransformerOptimus/SuperAGI
http_port: 3000
instance_count: 1
instance_size_slug: basic-xs
source_dir: ./gui
name: superagi-gui
workers:
- dockerfile_path: Dockerfile
envs:
- key: REDIS_URL
scope: RUN_TIME
value: superagi-redis:6379
- key: DB_URL
scope: RUN_TIME
value: ${super-agi-main.DATABASE_URL}
github:
branch: main
deploy_on_push: true
repo: TransformerOptimus/SuperAGI
instance_count: 1
instance_size_slug: basic-xs
run_command: celery -A superagi.worker worker --beat --loglevel=info
source_dir: /
name: superagi-celery
================================================
FILE: .do/deploy.template.yaml
================================================
spec:
alerts:
- rule: DEPLOYMENT_FAILED
- rule: DOMAIN_FAILED
databases:
- engine: PG
name: super-agi-main
num_nodes: 1
size: basic-xs
version: "12"
ingress:
rules:
- component:
name: superagi-backend
match:
path:
prefix: /api
name: superagi
services:
- dockerfile_path: DockerfileRedis
git:
branch: main
repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git
internal_ports:
- 6379
instance_count: 1
instance_size_slug: basic-xs
source_dir: /
name: superagi-redis
- dockerfile_path: Dockerfile
envs:
- key: REDIS_URL
scope: RUN_TIME
value: superagi-redis:6379
- key: DB_URL
scope: RUN_TIME
value: ${super-agi-main.DATABASE_URL}
git:
branch: main
repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git
http_port: 8001
instance_count: 1
instance_size_slug: basic-xs
run_command: /app/entrypoint.sh
source_dir: /
name: superagi-backend
- dockerfile_path: ./gui/DockerfileProd
git:
branch: main
repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git
http_port: 3000
instance_count: 1
instance_size_slug: basic-xs
source_dir: ./gui
name: superagi-gui
workers:
- dockerfile_path: Dockerfile
envs:
- key: REDIS_URL
scope: RUN_TIME
value: superagi-redis:6379
- key: DB_URL
scope: RUN_TIME
value: ${super-agi-main.DATABASE_URL}
git:
branch: main
repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git
instance_count: 1
instance_size_slug: basic-xs
run_command: celery -A superagi.worker worker --beat --loglevel=info
source_dir: /
name: superagi-celery
================================================
FILE: .dockerignore
================================================
# Ignore everything
**
# Allow files and directories
!/migrations
!/nginx
!/superagi
!/tgwui
!/tools
!/workspace
!/main.py
!/requirements.txt
!/entrypoint.sh
!/entrypoint_celery.sh
!/wait-for-it.sh
!/tools.json
!/install_tool_dependencies.sh
!/alembic.ini
================================================
FILE: .gitattributes
================================================
*.sh text eol=lf
================================================
FILE: .github/ISSUE_TEMPLATE/1.BUG_REPORT.yml
================================================
name: Bug report
description: Create a bug report for SuperAGI.
labels: ['status: needs triage']
body:
- type: markdown
attributes:
value: |
### ⚠️ Issue Creation Guideline
* Check out our [roadmap] and join our [discord] to discuss what's going on
* If you need help, you can ask in the [#general] section or in [#support]
* **Thoroughly search the [existing issues] before creating a new one**
* Read through our docs:
[roadmap]: https://github.com/users/TransformerOptimus/projects/5
[discord]: https://discord.gg/dXbRe5BHJC
[#general]: https://discord.com/channels/1107593006032355359/1107642413993959505
[#support]: https://discord.com/channels/1107593006032355359/1107645922797703198
[existing issues]: https://github.com/TransformerOptimus/SuperAGI/issues
- type: checkboxes
attributes:
label: ⚠️ Check for existing issues before proceeding. ⚠️
description: >
Please [search the history](https://github.com/TransformerOptimus/SuperAGI/issues)
to see if an issue already exists for the same problem.
options:
- label: I have searched the existing issues, and there is no existing issue for my problem
required: true
- type: markdown
attributes:
value: |
Please confirm that the issue you have is described well and precisely in the title above ⬆️.
Think like this: What would you type if you were searching for the issue?
For example:
❌ - my SuperAGI agent keeps looping
✅ - After performing Write Tool, SuperAGI goes into a loop where it keeps trying to write the file.
Please help us help you by following these steps:
- Search for existing issues first: adding a comment to an existing issue when you have the same or a similar problem is tidier than opening a new one, and
newer issues will not be reviewed any sooner, since review order depends on the current priorities set by our wonderful team
- Ask on our Discord if your issue is known when you are unsure (https://discord.gg/dXbRe5BHJC)
- Provide relevant info:
- Provide Docker logs (docker compose logs) whenever possible.
- If it's a pip/packages issue, mention this in the title and provide pip version, python version.
- type: dropdown
attributes:
label: Where are you using SuperAGI?
description: >
Please select the operating system you were using to run SuperAGI when this problem occurred.
options:
- Windows
- Linux
- MacOS
- Codespaces
- Web Version
- Other
validations:
required: true
nested_fields:
- type: text
attributes:
label: Specify the system
description: Please specify the system you are working on.
- type: dropdown
attributes:
label: Which branch of SuperAGI are you using?
description: |
Please select which version of SuperAGI you were using when this issue occurred.
If you installed SuperAGI with git, you can run `git branch` to see which branch of SuperAGI you are running.
options:
- Main
- Dev (branch)
validations:
required: true
- type: dropdown
attributes:
label: Do you use OpenAI GPT-3.5 or GPT-4?
description: >
If you are using SuperAGI with GPT-3.5, your problems may be caused by
the limitations of GPT-3.5, such as incorrect tool selection, which can cause the agent feed to loop.
options:
- GPT-3.5
- GPT-3.5(16k)
- GPT-4
- GPT-4(32k)
validations:
required: true
- type: dropdown
attributes:
label: Which area covers your issue best?
description: >
Select the area related to the issue you are reporting.
options:
- Installation and setup
- Resource Manager
- Action Console
- Performance
- Marketplace
- Prompt
- Tools
- Agents
- Documentation
- Logging
- Other
validations:
required: true
autolabels: true
nested_fields:
- type: text
attributes:
label: Specify the area
description: Please specify the area you think is best related to the issue.
- type: textarea
attributes:
label: Describe your issue.
description: Describe the problem you are experiencing. Describe only the issue, and keep it short but clear. ⚠️ Provide NO other data in this field
validations:
required: true
- type: textarea
attributes:
label: How to replicate your Issue?
description: |
Mention Agent Name, Agent Description and Agent Goals, along with Model selected.
Provide any other data which might be relevant for us to replicate this issue.
⚠️ Provide NO other data in this field
validations:
required: false
- type: markdown
attributes:
value: |
⚠️ Please keep in mind that the log files may contain personal information such as credentials. Make sure you hide them before copy/pasting! ⚠️
- type: input
attributes:
label: Upload Error Log Content
description: |
Upload the error log content; this can help us understand the issue better.
To do this, copy the logs from the terminal in which you ran 'docker compose up', or open a new terminal,
run 'docker compose logs', and copy/paste the error contents into this field.
⚠️ The activity log may contain personal data you gave SuperAGI in prompts or inputs, as well as
any personal information that SuperAGI collected from files during the last run. Please hide this information before sharing. ⚠️
validations:
required: True
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
### Description
### Related Issues
### Solution and Design
### Test Plan
### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs update
### Checklist
- [ ] My pull request is atomic and focuses on a single change.
- [ ] I have read the contributing guide and my code conforms to the guidelines.
- [ ] I have documented my changes clearly and comprehensively.
- [ ] I have added the required tests.
================================================
FILE: .github/workflows/ci.yml
================================================
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: Python CI
on:
push:
branches: [ "main", "dev" ]
pull_request:
branches: [ "main", "dev" ]
permissions:
contents: read
jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.ref }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.9"
- name: Cache Python dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
test:
permissions:
# Gives the action the necessary permissions for publishing new
# comments in pull requests.
pull-requests: write
# Gives the action the necessary permissions for pushing data to the
# python-coverage-comment-action branch, and for editing existing
# comments (to avoid publishing multiple comments in the same PR)
contents: write
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
matrix:
python-version: ["3.9"]
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.ref }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
submodules: true
- name: Configure git user SuperAGI-Bot
run: |
git config --global user.name "SuperAGI-Bot"
git config --global user.email "github-bot@superagi.com"
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.9"
- name: Cache dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
pytest --cov=superagi --cov-branch --cov-report term-missing --cov-report xml \
tests/unit_tests -s
env:
CI: true
ENV: DEV
PLAIN_OUTPUT: True
REDIS_URL: "localhost:6379"
IS_TESTING: True
ENCRYPTION_KEY: "abcdefghijklmnopqrstuvwxyz123456"
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
================================================
FILE: .github/workflows/codeql.yml
================================================
name: "CodeQL"
on:
push:
branches: [ 'main', 'dev' ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ 'main' ]
schedule:
- cron: '48 0 * * 2'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'javascript', 'python' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
# Use only 'java' to analyze code written in Java, Kotlin or both
# Use only 'javascript' to analyze code written in JavaScript, TypeScript or both
# Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
steps:
- name: Checkout repository
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v2
# ℹ️ Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
# If the Autobuild fails above, remove it and uncomment the following three lines.
# Modify them (or add more) to build your code if your project requires it; see the example below for guidance.
# - run: |
# echo "Run, Build Application using script"
# ./location_of_script_within_repo/buildscript.sh
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{matrix.language}}"
================================================
FILE: .gitignore
================================================
.idea
**/.env
**/.venv
config.yaml
__pycache__
superagi/models/__pycache__
superagi/controllers/__pycache__
**agent_dictvenv
**/__gitpycache__/
gui/node_modules
node_modules
gui/.next
.DS_Store
.DS_Store?
venv
workspace/output
workspace/input
celerybeat-schedule
../bfg-report*
superagi/tools/marketplace_tools/
superagi/tools/external_tools/
tests/unit_tests/resource_manager/test_path
/tools.json
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [python]
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
================================================
FILE: CONTRIBUTING.md
================================================
# ⚡ Contributing to SuperAGI
First of all, thank you for taking the time to contribute to this project. We truly appreciate your contributions, whether it's bug reports, feature suggestions, or pull requests. Your time and effort are highly valued in this project. 🚀
This document provides guidelines and best practices to help you to contribute effectively. These are meant to serve as guidelines, not strict rules. We encourage you to use your best judgment and feel comfortable proposing changes to this document through a pull request.
For all contributions, a CLA (Contributor License Agreement) needs to be signed
[here](https://cla-assistant.io/TransformerOptimus/SuperAGI) before (or after) the pull request has been submitted.
**Table of Contents:**
1. [Code of conduct](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md#code-of-conduct)
2. [Quick Start](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md#quick-start)
3. [Contributing Guidelines](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md#contributing-guidelines)
1. [Reporting Bugs](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md#reporting-bugs)
2. [New Feature or Suggesting Enhancements](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md#new-feature-or-suggesting-enhancements)
4. [Testing](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md#testing-your-changes)
5. [Pull Requests](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md#pull-request)
## ✔️ Code of Conduct:
Please read our [Code of Conduct](https://github.com/TransformerOptimus/SuperAGI/blob/main/CODE_OF_CONDUCT.md) to understand the expectations we have for all contributors participating in this project. By participating, you agree to abide by our Code of Conduct.
## 🚀 Quick Start
You can quickly get started with contributing by searching for issues with the labels **"Good First Issue"** or **"Help Needed"** in the [Issues Section](https://github.com/TransformerOptimus/SuperAGI/issues). If you think you can contribute, comment on the issue and we will assign it to you.
To set up your development environment, please follow the steps below:
1. Fork the repository and create a clone of the fork
2. Create a branch for a feature or a bug you are working on in your fork
3. Once you've created your branch, follow the instructions in the [README.MD](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
## Contributing Guidelines
### 🔍 Reporting Bugs
You can start working on an existing bug that has been reported and labeled as **"Bug"** in the Issues Section, and you can report your bugs in the following manner:
1. Title describing the issue clearly and concisely with relevant labels
2. Provide a detailed description of the problem and the necessary steps to reproduce the issue.
3. Include any relevant logs, screenshots, or other helpful information supporting the issue.
### :bulb: New Feature or Suggesting Enhancements
This section guides you through working on an enhancement **Including a completely New Feature** & **Enhancements to an existing functionality**.
Before getting started, perform a search on Issues to see if the enhancement or feature has already been suggested and picked up. If the feature or enhancement has been suggested but not picked up, comment on the issue to have it assigned to you.
If the feature or enhancement is not in the issues, find out whether your idea fits with the scope and aims of the project by looking at the [Roadmap](https://github.com/users/TransformerOptimus/projects/5/). If yes, raise an issue with the label **"Feature Request"** in the following manner:
1. Title describing the feature or enhancement in a clear and concise manner
2. Clearly describe the proposed enhancement, highlighting its benefits and potential drawbacks.
3. Provide examples and supporting information.
Once you have raised the issue and have gotten yourself assigned, you can start working on the feature or enhancement. Please make sure the feature or enhancement you're working on is placed on the Roadmap.
## Testing your Changes
Each method or function should have unit tests with the highest coverage possible. Every pull request triggers GitHub Actions, which run all the unit tests; all tests must pass before the pull request can be merged.
## Pull Request
Now that you have worked on your code and tested it thoroughly, you can now go ahead and raise the pull request. Please make sure that the Pull Request adheres to the following guidelines:
1. The pull request is atomic and focuses on a single change.
2. You have read the contributing guide and your code conforms to the guidelines.
3. You have documented your changes clearly and comprehensively.
4. You have added the required tests.
================================================
FILE: Dockerfile
================================================
# Stage 1: Compile image
# Builds the Python virtualenv (gcc/g++ and libpq-dev are installed here,
# presumably to build native extensions), pre-downloads NLTK data and copies
# the application source.
FROM python:3.10-slim-bullseye AS compile-image
WORKDIR /app
RUN apt-get update && \
apt-get install --no-install-recommends -y wget libpq-dev gcc g++ && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Self-contained virtualenv so stage 2 can copy all Python dependencies in one step.
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Copy requirements.txt on its own first so the dependency layer stays cached
# until the requirements change.
COPY requirements.txt .
RUN pip install --upgrade pip && \
pip install --no-cache-dir -r requirements.txt
# Fetch the NLTK 'punkt' and 'averaged_perceptron_tagger' data; it lands in
# /root/nltk_data, which stage 2 copies.
RUN python3.10 -c "import nltk; nltk.download('punkt')" && \
python3.10 -c "import nltk; nltk.download('averaged_perceptron_tagger')"
COPY . .
RUN chmod +x ./entrypoint.sh ./wait-for-it.sh ./install_tool_dependencies.sh ./entrypoint_celery.sh
# Stage 2: Build image
# Runtime image: installs only libpq-dev from apt and copies the venv, the app
# and the NLTK data from stage 1, leaving the compilers behind.
FROM python:3.10-slim-bullseye AS build-image
WORKDIR /app
RUN apt-get update && \
apt-get install --no-install-recommends -y libpq-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
COPY --from=compile-image /opt/venv /opt/venv
COPY --from=compile-image /app /app
COPY --from=compile-image /root/nltk_data /root/nltk_data
ENV PATH="/opt/venv/bin:$PATH"
EXPOSE 8001
# NOTE(review): unlike Dockerfile-gpu, no CMD/ENTRYPOINT is set here;
# presumably the compose file supplies the command — confirm.
================================================
FILE: Dockerfile-gpu
================================================
# Define the CUDA SDK version you need
ARG CUDA_IMAGE="12.1.1-devel-ubuntu22.04"
FROM nvidia/cuda:${CUDA_IMAGE}
# Keep apt from prompting for input during the build.
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app
# System packages: Python 3 + venv, build tools, libpq, and the OpenCL /
# CLBlast / OpenBLAS development libraries. The final command registers the
# NVIDIA OpenCL library as an ICD vendor so the OpenCL loader can find it.
RUN apt-get update && apt-get upgrade -y \
&& apt-get install -y git build-essential \
python3 python3-pip python3.10-venv libpq-dev gcc wget \
ocl-icd-opencl-dev opencl-headers clinfo \
libclblast-dev libopenblas-dev \
&& mkdir -p /etc/OpenCL/vendors && echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd
# Create a virtual environment and "activate" it by putting its bin/ first on PATH
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Install Python dependencies from requirements.txt
COPY requirements.txt .
RUN pip install --upgrade pip && \
pip install --no-cache-dir -r requirements.txt
# Pre-download the NLTK 'punkt' and 'averaged_perceptron_tagger' data packages
RUN python3.10 -c "import nltk; nltk.download('punkt')" && \
python3.10 -c "import nltk; nltk.download('averaged_perceptron_tagger')"
# Copy the application code
COPY . .
# Force-reinstall llama-cpp-python 0.2.7 built with CMAKE_ARGS="-DLLAMA_CUBLAS=on"
# (cuBLAS GPU support); --no-cache-dir prevents reuse of a previously built wheel.
ENV CUDA_DOCKER_ARCH=all
ENV LLAMA_CUBLAS=1
RUN CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python==0.2.7 --force-reinstall --upgrade --no-cache-dir
# Make necessary scripts executable
RUN chmod +x ./entrypoint.sh ./wait-for-it.sh ./install_tool_dependencies.sh ./entrypoint_celery.sh
# Set environment variable to point to the custom libllama.so
# ENV LLAMA_CPP_LIB=/app/llama.cpp/libllama.so
EXPOSE 8001
CMD ["./entrypoint.sh"]
================================================
FILE: DockerfileCelery
================================================
FROM python:3.9
WORKDIR /app
#RUN apt-get update && apt-get install --no-install-recommends -y git wget libpq-dev gcc python3-dev && pip install psycopg2
RUN pip install --upgrade pip
# Copy requirements.txt on its own first so the dependency layer stays cached
# until the requirements change.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# Download the tools (done once; the sequence below used to be repeated verbatim)
RUN python superagi/tool_manager.py
# Make the tool-dependency installer executable and run it
RUN chmod +x install_tool_dependencies.sh && ./install_tool_dependencies.sh
# Celery worker with an embedded beat scheduler
CMD ["celery", "-A", "superagi.worker", "worker", "--beat","--loglevel=info"]
================================================
FILE: DockerfileRedis
================================================
FROM redis/redis-stack-server:latest
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 TransformerOptimus
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.MD
================================================
Open-source framework to build, manage and run useful Autonomous AI Agents
Follow SuperAGI
Connect with the Creator
Share SuperAGI Repository
## What are we?
A dev-first open source autonomous AI agent framework enabling developers to build, manage & run useful autonomous agents. You can run concurrent agents seamlessly, extend agent capabilities with tools. The agents efficiently perform a variety of tasks and continually improve their performance with each subsequent run.
### 💡 Features
- Provision, Spawn & Deploy Autonomous AI Agents - Create production-ready & scalable autonomous agents.
- Extend Agent Capabilities with Toolkits - Add Toolkits from our marketplace to your agent workflows.
- Graphical User Interface - Access your agents through a graphical user interface.
- Action Console - Interact with agents by giving them input and permissions.
- Multiple Vector DBs - Connect to multiple Vector DBs to enhance your agent’s performance.
- Performance Telemetry - Get insights into your agent’s performance and optimize accordingly.
- Optimized Token Usage - Control token usage to manage costs effectively.
- Agent Memory Storage - Enable your agents to learn and adapt by storing their memory.
- Models - Custom fine-tuned models for business-specific use cases.
- Workflows - Automate tasks with ease using ReAct LLM's predefined steps.
### 🛠 Toolkits
Toolkits allow SuperAGI Agents to interact with external systems and third-party plugins.
### ⚙️ Installation
You can install SuperAGI using one of the following three approaches.
#### ☁️ SuperAGI cloud
To quickly start experimenting with agents without the hassle of setting up the system, try [SuperAGI Cloud](https://app.superagi.com/)
1. Visit [SuperAGI Cloud](https://app.superagi.com/) and log in using your GitHub account.
2. In your account settings, go to "Model Providers" and add your API key.
You're all set! Start running your agents effortlessly.
#### 🖥️ Local
1. Open your terminal and clone the SuperAGI repository.
```
git clone https://github.com/TransformerOptimus/SuperAGI.git
```
2. Navigate to the cloned repository directory using the command:
```
cd SuperAGI
```
3. Create a copy of config_template.yaml, and name it config.yaml.
4. Ensure that Docker is installed on your system. You can download and install it from [here](https://docs.docker.com/get-docker/).
5. Once you have Docker Desktop running, run the following command in the SuperAGI directory:
a. For regular usage:
```
docker compose -f docker-compose.yaml up --build
```
b. If you want to use SuperAGI with Local LLMs and have GPU, run the following command:
```
docker compose -f docker-compose-gpu.yml up --build
```
6. Open your web browser and navigate to http://localhost:3000 to access SuperAGI.
#### 🌀 Digital Ocean
Deploy SuperAGI to DigitalOcean with one click.
### 🌐 Architecture
SuperAGI Architecture

Agent Architecture

Agent Workflow Architecture

Tools Architecture

ER Diagram

### 📚 Resources
* [Documentation](https://superagi.com/docs/)
* [YouTube Channel](https://www.youtube.com/@_SuperAGI/videos)
### 📖 Need Help?
Join our [Discord community](https://discord.gg/dXbRe5BHJC) for support and discussions.
[](https://discord.gg/uJ3XUGsY2R)
If you have questions or encounter issues, please don't hesitate to [create a new issue](https://github.com/TransformerOptimus/SuperAGI/issues/new/choose) to get support.
### 💻 Contribution
We ❤️ our contributors. We’re committed to fostering an open, welcoming, and safe environment in the community.
If you'd like to contribute, start by reading our [Contribution Guide](https://github.com/TransformerOptimus/SuperAGI/blob/main/CONTRIBUTING.md).
We expect everyone participating in the community to abide by our [Code of Conduct](https://github.com/TransformerOptimus/SuperAGI/blob/main/CODE_OF_CONDUCT.md).
To get a better idea of where we are heading, check out our roadmap [here](https://github.com/users/TransformerOptimus/projects/5/views/1).
Explore some [good first issues](https://github.com/TransformerOptimus/SuperAGI/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to start contributing.
### 👩💻 Contributors
[](https://github.com/TransformerOptimus) [](https://github.com/Cptsnowcrasher) [](https://github.com/vectorcrow) [](https://github.com/Akki-jain) [](https://github.com/Autocop-Agent)[](https://github.com/COLONAYUSH)[](https://github.com/luciferlinx101)[](https://github.com/mukundans89)[](https://github.com/Fluder-Paradyne)[](https://github.com/nborthy)[](https://github.com/nihirr)[](https://github.com/Tarraann)[](https://github.com/neelayan7)[](https://github.com/Arkajit-Datta)[](https://github.com/guangchen811)[](https://github.com/juanfpo96)[](https://github.com/iskandarreza)[](https://github.com/jpenalbae)[](https://github.com/pallasite99)[](https://github.com/xutpuu)[](https://github.com/alexkreidler)[](https://github.com/hanhyalex123)[](https://github.com/ps4vs)[](https://github.com/eltociear)
[](https://github.com/shaiss)
[](https://github.com/AdityaRajSingh1992)
[](https://github.com/namansleeps22)
[](https://github.com/sirajperson)
[](https://github.com/hsm207)
[](https://github.com/unkn-wn)
[](https://github.com/DMTarmey)
[](https://github.com/Parth2506)
[](https://github.com/platinaCoder)
[](https://github.com/anisha1607)
[](https://github.com/jorgectf)
[](https://github.com/PaulRBerg)
[](https://github.com/boundless-asura)
[](https://github.com/JPDucky)
[](https://github.com/Vibhusha22)
[](https://github.com/ai-akuma)
[](https://github.com/rounak610)
[](https://github.com/AdarshJha619)
[](https://github.com/ResoluteStoic)
[](https://github.com/JohnHunt999)
[](https://github.com/Maverick-F359)
[](https://github.com/jorgectf)
[](https://github.com/AdityaSharma13064)
[](https://github.com/lalitlj)
[](https://github.com/andrew-kelly-neutralaiz)
[](https://github.com/sayan1101)
### ⚠️ Under Development!
This project is under active development and may still have issues. We appreciate your understanding and patience. If you encounter any problems, please check the open issues first. If your issue is not listed, kindly create a new issue detailing the error or problem you experienced. Thank you for your support!
================================================
FILE: alembic.ini
================================================
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = postgresql://superagi:password@super__postgres:5432/super_agi_main
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
================================================
FILE: cli2.py
================================================
import os
import sys
import subprocess
from time import sleep
import shutil
from sys import platform
from multiprocessing import Process
from superagi.lib.logger import logger
def check_command(command, message):
    """Abort the launcher when a required executable is not available.

    Args:
        command: Executable name (or path) looked up with ``shutil.which``.
        message: Text logged at INFO level just before exiting when the
            executable cannot be found.

    Raises:
        SystemExit: with status 1 when ``command`` is not found.
    """
    if shutil.which(command) is not None:
        return
    logger.info(message)
    sys.exit(1)
def run_npm_commands(shell=False):
    """Install the GUI's npm dependencies by running ``npm install`` inside ``gui/``.

    Args:
        shell: Forwarded to ``subprocess.run``; the ``__main__`` block enables
            it on Windows/Cygwin.

    Raises:
        SystemExit: with status 1 if ``npm install`` exits with a non-zero status.
    """
    try:
        # Run inside gui/ via ``cwd`` rather than os.chdir so this process's
        # working directory is never changed, even on failure.
        subprocess.run(["npm", "install"], check=True, shell=shell, cwd="gui")
    except subprocess.CalledProcessError as exc:
        logger.error(f"Error during '{' '.join(exc.cmd)}'. Exiting.")
        sys.exit(1)
def run_server(shell=False, a_name=None, a_description=None, goals=None):
    """Start the API server, the Celery worker and the agent script as child processes.

    Each child process calls ``subprocess.run`` on one command line:
      * ``uvicorn main:app --host 0.0.0.0 --port 8000``
      * ``celery -A celery_app worker --loglevel=info``
      * ``python test.py --name <a_name> --description <a_description> --goals <goals...>``

    Args:
        shell: Forwarded to every ``subprocess.run`` call.
        a_name: Agent name passed after ``--name``.
        a_description: Agent description passed after ``--description``.
        goals: Iterable of goal strings appended after ``--goals``; ``None``
            (the default) means no goals.

    Returns:
        Tuple ``(api_process, ui_process, celery_process)`` of already-started
        ``multiprocessing.Process`` objects.
    """
    # Accept the None default (concatenating a list with None raises TypeError)
    # and any iterable of goals.
    goal_args = list(goals) if goals is not None else []
    agent_command = ["python", "test.py", "--name", a_name, "--description", a_description, "--goals"] + goal_args

    api_process = Process(target=subprocess.run, args=(["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"],), kwargs={"shell": shell})
    celery_process = Process(target=subprocess.run, args=(["celery", "-A", "celery_app", "worker", "--loglevel=info"],), kwargs={"shell": shell})
    # NOTE(review): despite its name, ui_process runs test.py with the agent
    # arguments rather than the GUI — confirm this is intended.
    ui_process = Process(target=subprocess.run, args=(agent_command,), kwargs={"shell": shell})

    api_process.start()
    celery_process.start()
    ui_process.start()
    return api_process, ui_process, celery_process
def cleanup(api_process, ui_process, celery_process, timeout=10):
    """Terminate the child processes started by ``run_server`` and exit.

    Args:
        api_process: Process returned first by ``run_server``.
        ui_process: Process returned second by ``run_server``.
        celery_process: Process returned third by ``run_server``.
        timeout: Seconds to wait for each process to finish after
            ``terminate()`` (passed to ``Process.join``).

    Raises:
        SystemExit: always, with status 1.
    """
    logger.info("Shutting down processes...")
    processes = (api_process, ui_process, celery_process)
    # Signal all children first so they shut down in parallel ...
    for process in processes:
        process.terminate()
    # ... then join them so they are reaped before we report success.
    for process in processes:
        process.join(timeout)
    logger.info("Processes terminated. Exiting.")
    sys.exit(1)
if __name__ == "__main__":
    # Verify the required executables are on PATH before prompting the user.
    check_command("node", "Node.js is not installed. Please install it and try again.")
    check_command("npm", "npm is not installed. Please install npm to proceed.")
    check_command("uvicorn", "uvicorn is not installed. Please install uvicorn to proceed.")

    # Collect the agent definition interactively; typing 'q' ends goal entry.
    agent_name = input("Enter an agent name: ")
    agent_description = input("Enter an agent description: ")
    goals = []
    while True:
        goal = input("Enter a goal (or 'q' to quit): ")
        if goal == 'q':
            break
        goals.append(goal)

    # Commands are launched through the shell on Windows/Cygwin.
    is_windows = platform in ("win32", "cygwin")
    run_npm_commands(shell=is_windows)

    try:
        api_process, ui_process, celery_process = run_server(is_windows, agent_name, agent_description, goals)
    except Exception as exc:
        # run_server raised before returning its process handles, so there is
        # nothing to pass to cleanup(); report the failure and exit instead.
        logger.error(f"Failed to start processes: {exc}. Exiting.")
        sys.exit(1)

    try:
        # Keep the launcher alive while the child processes run.
        while True:
            sleep(30)
    except (KeyboardInterrupt, Exception):
        cleanup(api_process, ui_process, celery_process)
================================================
FILE: config_template.yaml
================================================
#####################------------------SYSTEM KEYS-------------------------########################
PINECONE_API_KEY: YOUR_PINECONE_API_KEY
PINECONE_ENVIRONMENT: YOUR_PINECONE_ENVIRONMENT
OPENAI_API_KEY: YOUR_OPEN_API_KEY
PALM_API_KEY: YOUR_PALM_API_KEY
REPLICATE_API_TOKEN: YOUR_REPLICATE_API_TOKEN
HUGGING_API_TOKEN: YOUR_HUGGING_FACE_API_TOKEN
# For locally hosted LLMs comment out the next line and uncomment the one after
# to configure a local llm point your browser to 127.0.0.1:7860 and click on the model tab in text generation web ui.
OPENAI_API_BASE: https://api.openai.com/v1
#OPENAI_API_BASE: "http://super__tgwui:5001/v1"
# "gpt-3.5-turbo-0301": 4032, "gpt-4-0314": 8092, "gpt-3.5-turbo": 4032, "gpt-4": 8092, "gpt-4-32k": 32768, "gpt-4-32k-0314": 32768, "llama":2048, "mpt-7b-storywriter":45000
MODEL_NAME: "gpt-3.5-turbo-0301"
# Examples: "gpt-3.5-turbo", "gpt-4", "models/chat-bison-001"
RESOURCES_SUMMARY_MODEL_NAME: "gpt-3.5-turbo"
MAX_TOOL_TOKEN_LIMIT: 800
MAX_MODEL_TOKEN_LIMIT: 4032 # set to 2048 for llama
#DATABASE INFO
# Postgres and Redis connection details
DB_NAME: super_agi_main
DB_HOST: super__postgres
DB_USERNAME: superagi
DB_PASSWORD: password
DB_URL: postgresql://superagi:password@super__postgres:5432/super_agi_main
REDIS_URL: "super__redis:6379"
#STORAGE TYPE ("FILE" or "S3")
STORAGE_TYPE: "FILE"
#TOOLS
TOOLS_DIR: "superagi/tools"
#STORAGE INFO FOR FILES
RESOURCES_INPUT_ROOT_DIR: workspace/input/{agent_id}
RESOURCES_OUTPUT_ROOT_DIR: workspace/output/{agent_id}/{agent_execution_id} # For keeping resources at agent execution level
#RESOURCES_OUTPUT_ROOT_DIR: workspace/output/{agent_id} # For keeping resources at agent level
#S3 RELATED DETAILS ONLY WHEN STORAGE_TYPE IS "S3"
BUCKET_NAME:
INSTAGRAM_TOOL_BUCKET_NAME: #Public read bucket, Images generated by stable diffusion are put in this bucket and the public url of the same is generated.
AWS_ACCESS_KEY_ID:
AWS_SECRET_ACCESS_KEY:
#AUTH
ENV: 'DEV' # DEV or PROD; set to PROD to use GitHub OAuth
JWT_SECRET_KEY: 'secret'
expiry_time_hours: 1
#GITHUB OAUTH:
GITHUB_CLIENT_ID:
GITHUB_CLIENT_SECRET:
FRONTEND_URL: "http://localhost:3000"
#ENCRYPTION KEY, Replace this with your own key for production
ENCRYPTION_KEY: abcdefghijklmnopqrstuvwxyz123456
#WEAVIATE
# If you are using a Docker-hosted or web-hosted Weaviate instance, uncomment the next two lines and comment out the third one
# WEAVIATE_URL: YOUR_WEAVIATE_URL
# WEAVIATE_API_KEY: YOUR_WEAVIATE_API_KEY
WEAVIATE_USE_EMBEDDED: true
#####################------------------TOOLS KEY-------------------------########################
#If you have google api key and CSE key, use this
GOOGLE_API_KEY: YOUR_GOOGLE_API_KEY
SEARCH_ENGINE_ID: YOUR_SEARCH_ENIGNE_ID
# IF YOU DON'T HAVE A GOOGLE SEARCH KEY, YOU CAN USE A SERPER.DEV KEY
SERP_API_KEY: YOUR_SERPER_API_KEY
#ENTER YOUR EMAIL CREDENTIALS TO ACCESS EMAIL TOOL
EMAIL_ADDRESS: YOUR_EMAIL_ADDRESS
EMAIL_PASSWORD: YOUR_EMAIL_APP_PASSWORD #get the app password from (https://myaccount.google.com/apppasswords)
EMAIL_SMTP_HOST: smtp.gmail.com #Change the SMTP host if not using Gmail
EMAIL_SMTP_PORT: 587 #Change the SMTP port if not using Gmail
EMAIL_IMAP_SERVER: imap.gmail.com #Change the IMAP Host if not using Gmail
EMAIL_SIGNATURE: Email sent by SuperAGI
EMAIL_DRAFT_MODE_WITH_FOLDER: YOUR_DRAFTS_FOLDER
EMAIL_ATTACHMENT_BASE_PATH: YOUR_DIRECTORY_FOR_EMAIL_ATTACHMENTS
# GITHUB
GITHUB_USERNAME: YOUR_GITHUB_USERNAME
GITHUB_ACCESS_TOKEN: YOUR_GITHUB_ACCESS_TOKEN
#JIRA
JIRA_INSTANCE_URL: YOUR_JIRA_INSTANCE_URL
JIRA_USERNAME: YOUR_JIRA_EMAIL
JIRA_API_TOKEN: YOUR_JIRA_API_TOKEN
#SLACK
SLACK_BOT_TOKEN: YOUR_SLACK_BOT_TOKEN
# For running stable diffusion
STABILITY_API_KEY: YOUR_STABILITY_API_KEY
#Engine IDs that can be used: 'stable-diffusion-v1', 'stable-diffusion-v1-5','stable-diffusion-512-v2-0', 'stable-diffusion-768-v2-0','stable-diffusion-512-v2-1','stable-diffusion-768-v2-1','stable-diffusion-xl-beta-v2-2-2'
ENGINE_ID: "stable-diffusion-xl-beta-v2-2-2"
## To config a vector store for resources manager uncomment config below
## based on the vector store you want to use
## RESOURCE_VECTOR_STORE can be REDIS, PINECONE, CHROMA, QDRANT
#RESOURCE_VECTOR_STORE: YOUR_RESOURCE_VECTOR_STORE
#RESOURCE_VECTOR_STORE_INDEX_NAME: YOUR_RESOURCE_VECTOR_STORE_INDEX_NAME
## To use a custom redis
#REDIS_VECTOR_STORE_URL: YOUR_REDIS_VECTOR_STORE_URL
## To use qdrant for vector store in resources manager
#QDRANT_PORT: YOUR_QDRANT_PORT
#QDRANT_HOST_NAME: YOUR_QDRANT_HOST_NAME
## To use chroma for vector store in resources manager
#CHROMA_HOST_NAME: YOUR_CHROMA_HOST_NAME
#CHROMA_PORT: YOUR_CHROMA_PORT
## To use Qdrant for vector store
#QDRANT_HOST_NAME: YOUR_QDRANT_HOST_NAME
#QDRANT_PORT: YOUR_QDRANT_PORT
#GPU_LAYERS: GPU LAYERS THAT YOU WANT TO OFFLOAD TO THE GPU WHILE USING LOCAL LLMS
================================================
FILE: docker-compose-dev.yaml
================================================
version: '3.8'
services:
backend:
volumes:
- "./:/app"
build: .
depends_on:
- super__redis
- super__postgres
networks:
- super_network
command: ["/app/wait-for-it.sh", "super__postgres:5432","-t","60","--","/app/entrypoint.sh"]
celery:
volumes:
- "./:/app"
- "${EXTERNAL_RESOURCE_DIR:-./workspace}:/app/ext"
build: .
depends_on:
- super__redis
- super__postgres
networks:
- super_network
command: ["/app/entrypoint_celery.sh"]
gui:
build:
context: ./gui
args:
NEXT_PUBLIC_API_BASE_URL: "/api"
networks:
- super_network
# volumes:
# - ./gui:/app
# - /app/node_modules/
# - /app/.next/
super__redis:
image: "redis/redis-stack-server:latest"
networks:
- super_network
# uncomment to expose redis port to host
# ports:
# - "6379:6379"
volumes:
- redis_data:/data
super__postgres:
# Pinned to the same major version as docker-compose.yaml. With ":latest" the
# next Postgres major release would be pulled, and a new major server cannot
# start on a data volume initialised by an older major version.
image: "docker.io/library/postgres:15"
environment:
- POSTGRES_USER=superagi
- POSTGRES_PASSWORD=password
- POSTGRES_DB=super_agi_main
volumes:
- superagi_postgres_data:/var/lib/postgresql/data/
networks:
- super_network
# uncomment to expose postgres port to host
# ports:
# - "5432:5432"
proxy:
image: nginx:stable-alpine
ports:
- "3000:80"
networks:
- super_network
depends_on:
- backend
- gui
volumes:
- ./nginx/default.conf:/etc/nginx/conf.d/default.conf
networks:
super_network:
driver: bridge
volumes:
superagi_postgres_data:
redis_data:
================================================
FILE: docker-compose-gpu.yml
================================================
version: '3.8'
services:
backend:
volumes:
- "./:/app"
build:
context: .
dockerfile: Dockerfile-gpu
depends_on:
- super__redis
- super__postgres
networks:
- super_network
command: ["/app/wait-for-it.sh", "super__postgres:5432","-t","60","--","/app/entrypoint.sh"]
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
celery:
volumes:
- "./:/app"
- "${EXTERNAL_RESOURCE_DIR:-./workspace}:/app/ext"
build:
context: .
dockerfile: Dockerfile-gpu
depends_on:
- super__redis
- super__postgres
networks:
- super_network
command: ["/app/entrypoint_celery.sh"]
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
gui:
build:
context: ./gui
args:
NEXT_PUBLIC_API_BASE_URL: "/api"
networks:
- super_network
# volumes:
# - ./gui:/app
# - /app/node_modules/
# - /app/.next/
super__redis:
image: "redis/redis-stack-server:latest"
networks:
- super_network
# uncomment to expose redis port to host
# ports:
# - "6379:6379"
volumes:
- redis_data:/data
super__postgres:
image: "docker.io/library/postgres:15"
environment:
- POSTGRES_USER=superagi
- POSTGRES_PASSWORD=password
- POSTGRES_DB=super_agi_main
volumes:
- superagi_postgres_data:/var/lib/postgresql/data/
networks:
- super_network
# uncomment to expose postgres port to host
# ports:
# - "5432:5432"
proxy:
image: nginx:stable-alpine
ports:
- "3000:80"
networks:
- super_network
depends_on:
- backend
- gui
volumes:
- ./nginx/default.conf:/etc/nginx/conf.d/default.conf
networks:
super_network:
driver: bridge
volumes:
superagi_postgres_data:
redis_data:
================================================
FILE: docker-compose.image.example.yaml
================================================
version: '3.8'
services:
backend:
image: "superagidev/superagi:main"
depends_on:
- super__redis
- super__postgres
networks:
- super_network
env_file:
- config.yaml
command: ["/app/wait-for-it.sh", "super__postgres:5432","-t","60","--","/app/entrypoint.sh"]
celery:
image: "superagidev/superagi:main"
depends_on:
- super__redis
- super__postgres
networks:
- super_network
env_file:
- config.yaml
command: ["/app/entrypoint_celery.sh"]
volumes:
- "./workspace:/app/workspace"
gui:
image: "superagidev/superagi-frontend:main"
environment:
- NEXT_PUBLIC_API_BASE_URL=/api
networks:
- super_network
super__redis:
image: "redis/redis-stack-server:latest"
networks:
- super_network
# uncomment to expose redis port to host
# ports:
# - "6379:6379"
volumes:
- redis_data:/data
super__postgres:
# Pinned to the same major version as docker-compose.yaml. With ":latest" the
# next Postgres major release would be pulled, and a new major server cannot
# start on a data volume initialised by an older major version.
image: "docker.io/library/postgres:15"
environment:
- POSTGRES_USER=superagi
- POSTGRES_PASSWORD=password
- POSTGRES_DB=super_agi_main
volumes:
- superagi_postgres_data:/var/lib/postgresql/data/
networks:
- super_network
# uncomment to expose postgres port to host
# ports:
# - "5432:5432"
proxy:
image: nginx:stable-alpine
ports:
- "3000:80"
networks:
- super_network
depends_on:
- backend
- gui
volumes:
- ./nginx/default.conf:/etc/nginx/conf.d/default.conf
networks:
super_network:
driver: bridge
volumes:
superagi_postgres_data:
redis_data:
================================================
FILE: docker-compose.yaml
================================================
version: '3.8'
services:
backend:
volumes:
- "./:/app"
build: .
depends_on:
- super__redis
- super__postgres
networks:
- super_network
command: ["/app/wait-for-it.sh", "super__postgres:5432","-t","60","--","/app/entrypoint.sh"]
celery:
volumes:
- "./:/app"
- "${EXTERNAL_RESOURCE_DIR:-./workspace}:/app/ext"
build: .
depends_on:
- super__redis
- super__postgres
networks:
- super_network
command: ["/app/entrypoint_celery.sh"]
gui:
build:
context: ./gui
args:
NEXT_PUBLIC_API_BASE_URL: "/api"
networks:
- super_network
# volumes:
# - ./gui:/app
# - /app/node_modules/
# - /app/.next/
super__redis:
image: "redis/redis-stack-server:latest"
networks:
- super_network
# uncomment to expose redis port to host
# ports:
# - "6379:6379"
volumes:
- redis_data:/data
super__postgres:
image: "docker.io/library/postgres:15"
environment:
- POSTGRES_USER=superagi
- POSTGRES_PASSWORD=password
- POSTGRES_DB=super_agi_main
volumes:
- superagi_postgres_data:/var/lib/postgresql/data/
networks:
- super_network
# uncomment to expose postgres port to host
# ports:
# - "5432:5432"
proxy:
image: nginx:stable-alpine
ports:
- "3000:80"
networks:
- super_network
depends_on:
- backend
- gui
volumes:
- ./nginx/default.conf:/etc/nginx/conf.d/default.conf
networks:
super_network:
driver: bridge
volumes:
superagi_postgres_data:
redis_data:
================================================
FILE: entrypoint.sh
================================================
#!/bin/bash
# Backend container entrypoint.
#
# Steps:
#   1. Download tools from the marketplace and external tool repositories.
#   2. Install the tools' dependencies.
#   3. Apply Alembic migrations (`alembic upgrade head`).
#   4. Replace this shell with the uvicorn API server on 0.0.0.0:8001.
#
# Steps 1-2 remain best-effort, as before: a failure is reported on stderr and
# startup continues. A failed migration aborts startup with the migration's
# exit status instead of serving the API against an un-migrated schema.

# Print a diagnostic on stderr.
warn() { printf 'entrypoint.sh: %s\n' "$*" >&2; }

# Downloads the tools from marketplace and external tool repositories.
# ($? inside the warn message is still the exit status of the failed command.)
python superagi/tool_manager.py \
  || warn "tool download failed (exit $?); continuing"

# Install dependencies (best-effort, see header).
./install_tool_dependencies.sh \
  || warn "tool dependency install failed (exit $?); continuing"

# Run Alembic migrations; never start the API after a failed migration.
alembic upgrade head || {
  status=$?
  warn "alembic upgrade head failed (exit ${status}); aborting startup"
  exit "${status}"
}

# Start the app. exec replaces this shell with uvicorn instead of running it
# as a child process.
exec uvicorn main:app --host 0.0.0.0 --port 8001 --reload
================================================
FILE: entrypoint_celery.sh
================================================
#!/bin/bash
# Celery worker entrypoint: refresh the tool set, install what the tools
# require, then hand the process over to the Celery worker. The worker is
# started with --beat, so it also runs the beat scheduler.

# Fetch the tools.
python superagi/tool_manager.py

# Install the fetched tools' dependencies.
./install_tool_dependencies.sh

# exec: the worker replaces this shell rather than running as its child.
celery_args=(-A superagi.worker worker --beat --loglevel=info)
exec celery "${celery_args[@]}"
================================================
FILE: gui/.dockerignore
================================================
# Ignore everything
**
# Allow files and directories
!app
!pages
!public
!utils
!package.json
!next.config.js
!package-lock.json
!.eslintrc.json
!jsconfig.json
================================================
FILE: gui/.eslintrc.json
================================================
{
"extends": "next/core-web-vitals"
}
================================================
FILE: gui/Dockerfile
================================================
# Development image for the Next.js GUI: installs dependencies from the
# lockfile, copies the source, exposes port 3000 and runs `npm run dev`.
FROM node:18-alpine AS deps
RUN apk add --no-cache libc6-compat
WORKDIR /app
COPY package.json package-lock.json ./
# npm ci installs exactly the versions pinned in package-lock.json.
RUN npm ci
# Rebuild the source code only when needed
# NOTE(review): despite its name, this final stage does not build anything;
# it runs the development server (see CMD below).
FROM node:18-alpine AS builder
WORKDIR /app
COPY --from=deps /app/node_modules ./node_modules
COPY . .
# Build arguments re-exported as environment variables for the dev server.
ARG NEXT_PUBLIC_API_BASE_URL=/api
ENV NEXT_PUBLIC_API_BASE_URL=$NEXT_PUBLIC_API_BASE_URL
ARG NEXT_PUBLIC_MIXPANEL_AUTH_ID
ENV NEXT_PUBLIC_MIXPANEL_AUTH_ID=$NEXT_PUBLIC_MIXPANEL_AUTH_ID
EXPOSE 3000
CMD ["npm", "run", "dev"]
================================================
FILE: gui/DockerfileProd
================================================
# Production image for the Next.js GUI, built in three stages:
#   deps    - install npm dependencies from package-lock.json (no devDependencies)
#   builder - run `npm run build`; next.config.js sets output: 'standalone'
#   runner  - small image that runs the standalone server as a non-root user
FROM node:18-alpine AS deps
RUN apk add --no-cache libc6-compat
WORKDIR /app
COPY package.json package-lock.json ./
# --omit=dev is the current spelling of the deprecated --only=production flag.
RUN npm ci --omit=dev
# Rebuild the source code only when needed
FROM node:18-alpine AS builder
WORKDIR /app
COPY --from=deps /app/node_modules ./node_modules
COPY . .
# NEXT_PUBLIC_* values are read by `next build` and baked into the bundle.
ARG NEXT_PUBLIC_API_BASE_URL=/api
ENV NEXT_PUBLIC_API_BASE_URL=$NEXT_PUBLIC_API_BASE_URL
RUN npm run build
# Production image, copy all the files and run next
FROM node:18-alpine AS runner
WORKDIR /app
ENV NODE_ENV=production
RUN addgroup --system --gid 1001 supergroup
RUN adduser --system --uid 1001 superuser
COPY --from=builder /app/public ./public
COPY --from=builder /app/package.json ./package.json
# Automatically leverage output traces to reduce image size
# https://nextjs.org/docs/advanced-features/output-file-tracing
COPY --from=builder --chown=superuser:supergroup /app/.next/standalone ./
COPY --from=builder --chown=superuser:supergroup /app/.next/static ./.next/static
USER superuser
EXPOSE 3000
ENV PORT=3000
CMD ["node", "server.js"]
================================================
FILE: gui/README.md
================================================
This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).
## Getting Started
First, run the development server:
```bash
npm run dev
# or
yarn dev
# or
pnpm dev
```
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
You can start editing by modifying the components under `pages/` (for example `pages/Content/`). The page auto-updates as you edit the file.
This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font.
## Learn More
To learn more about Next.js, take a look at the following resources:
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome!
## Deploy on Vercel
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details.
================================================
FILE: gui/app/globals.css
================================================
:root {
--max-width: 1100px;
--border-radius: 12px;
--font-mono: ui-monospace, Menlo, Monaco, 'Cascadia Mono', 'Segoe UI Mono',
'Roboto Mono', 'Oxygen Mono', 'Ubuntu Monospace', 'Source Code Pro',
'Fira Mono', 'Droid Sans Mono', 'Courier New', monospace;
--foreground-rgb: 0, 0, 0;
--background-start-rgb: 214, 219, 220;
--background-end-rgb: 255, 255, 255;
--primary-glow: conic-gradient(
from 180deg at 50% 50%,
#16abff33 0deg,
#0885ff33 55deg,
#54d6ff33 120deg,
#0071ff33 160deg,
transparent 360deg
);
--secondary-glow: radial-gradient(
rgba(255, 255, 255, 1),
rgba(255, 255, 255, 0)
);
--tile-start-rgb: 239, 245, 249;
--tile-end-rgb: 228, 232, 233;
--tile-border: conic-gradient(
#00000080,
#00000040,
#00000030,
#00000020,
#00000010,
#00000010,
#00000080
);
--callout-rgb: 238, 240, 241;
--callout-border-rgb: 172, 175, 176;
--card-rgb: 180, 185, 188;
--card-border-rgb: 131, 134, 135;
}
@media (prefers-color-scheme: dark) {
:root {
--foreground-rgb: 255, 255, 255;
--background-start-rgb: 0, 0, 0;
--background-end-rgb: 0, 0, 0;
--primary-glow: radial-gradient(rgba(1, 65, 255, 0.4), rgba(1, 65, 255, 0));
--secondary-glow: linear-gradient(
to bottom right,
rgba(1, 65, 255, 0),
rgba(1, 65, 255, 0),
rgba(1, 65, 255, 0.3)
);
--tile-start-rgb: 2, 13, 46;
--tile-end-rgb: 2, 5, 19;
--tile-border: conic-gradient(
#ffffff80,
#ffffff40,
#ffffff30,
#ffffff20,
#ffffff10,
#ffffff10,
#ffffff80
);
--callout-rgb: 20, 20, 20;
--callout-border-rgb: 108, 108, 108;
--card-rgb: 100, 100, 100;
--card-border-rgb: 200, 200, 200;
}
}
* {
box-sizing: border-box;
padding: 0;
margin: 0;
}
html,
body {
max-width: 100vw;
overflow-x: hidden;
}
body {
color: rgb(var(--foreground-rgb));
background: linear-gradient(
to bottom,
transparent,
rgb(var(--background-end-rgb))
)
rgb(var(--background-start-rgb));
}
a {
color: inherit;
text-decoration: none;
}
@media (prefers-color-scheme: dark) {
html {
color-scheme: dark;
}
}
================================================
FILE: gui/app/layout.js
================================================
import './globals.css'
// Page metadata for the app-router root layout (the description is still the
// create-next-app boilerplate).
export const metadata = {
title: 'Super AGI',
description: 'Generated by create next app',
}
/**
 * Root layout of the `app/` directory; receives each page as `children`.
 *
 * NOTE(review): as it appears here, `return ({children})` evaluates to a plain
 * object `{children}`, not JSX. The <html>/<body> wrapper markup looks stripped
 * from this copy; verify against the real file before changing it.
 *
 * @param {{children: React.ReactNode}} props
 */
export default function RootLayout({ children }) {
return (
{children}
)
}
================================================
FILE: gui/jsconfig.json
================================================
{
"compilerOptions": {
"paths": {
"@/*": ["./*"]
}
}
}
================================================
FILE: gui/next.config.js
================================================
// True when building/running with NODE_ENV=production.
const isProductionBuild = process.env.NODE_ENV === "production";

/**
 * Next.js configuration.
 * - assetPrefix: "/" in production, relative "./" otherwise.
 * - output: 'standalone' makes `next build` emit .next/standalone, which the
 *   production Dockerfile copies into its runner image.
 *
 * @type {import('next').NextConfig}
 */
const nextConfig = {
  assetPrefix: isProductionBuild ? "/" : "./",
  output: 'standalone',
};

module.exports = nextConfig;
================================================
FILE: gui/package.json
================================================
{
"name": "super-agi",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev",
"build": "next build",
"start": "next start",
"lint": "next lint",
"export": "next export"
},
"dependencies": {
"axios": "^1.4.0",
"bootstrap": "^5.2.3",
"date-fns": "^2.30.0",
"date-fns-tz": "^2.0.0",
"echarts": "^5.4.2",
"echarts-for-react": "^3.0.2",
"eslint": "8.40.0",
"eslint-config-next": "13.4.2",
"js-cookie": "^3.0.5",
"jszip": "^3.10.1",
"mitt": "^3.0.0",
"mixpanel-browser": "^2.47.0",
"moment": "^2.29.4",
"moment-timezone": "^0.5.43",
"next": "13.4.2",
"react": "18.2.0",
"react-datetime": "^3.2.0",
"react-dom": "18.2.0",
"react-draggable": "^4.4.5",
"react-grid-layout": "^1.3.4",
"react-markdown": "^8.0.7",
"react-spinners": "^0.13.8",
"react-tippy": "^1.4.0",
"react-toastify": "^9.1.3"
}
}
================================================
FILE: gui/pages/Content/APM/Apm.module.css
================================================
.apm_dashboard_container {
display: flex;
flex-direction: column;
}
.apm_dashboard {
margin-top: 16px;
height: calc(100vh - 16vh);
overflow-y: auto;
}
================================================
FILE: gui/pages/Content/APM/ApmDashboard.js
================================================
import React, {useState, useEffect, useCallback, useRef} from 'react';
import Image from "next/image";
import style from "./Apm.module.css";
import 'react-toastify/dist/ReactToastify.css';
import {getActiveRuns, getAgentRuns, getAllAgents, getToolsUsage, getMetrics} from "@/pages/api/DashboardService";
import {formatNumber, formatTime, returnToolkitIcon} from "@/utils/utils";
import {BarGraph} from "./BarGraph.js";
import {WidthProvider, Responsive} from 'react-grid-layout';
import 'react-grid-layout/css/styles.css';
import 'react-resizable/css/styles.css';
import { Tooltip } from 'react-tippy';
// Width-aware wrapper (react-grid-layout WidthProvider) around the Responsive grid.
const ResponsiveGridLayout = WidthProvider(Responsive);
/**
 * Agent Performance Monitoring (APM) dashboard.
 *
 * Loads analytics (metrics, all agents, active runs, tool usage) on mount and
 * re-polls them every 10 seconds. When nothing is selected it picks the last
 * agent in the list, and it keeps the grid layout in localStorage under
 * 'myLayoutKey'.
 *
 * NOTE(review): the JSX returned at the bottom appears stripped in this copy.
 * Only text nodes and expressions remain, and some look like they belong to
 * other components. Verify against the real file before editing the render output.
 */
export default function ApmDashboard() {
// Analytics payloads from getMetrics() / getAllAgents().
const [agentDetails, setAgentDetails] = useState([]);
const [tokenDetails, setTokenDetails] = useState([]);
const [runDetails, setRunDetails] = useState(0);
const [allAgents, setAllAgents] = useState([]);
// Open/closed state of three dropdowns (all closed when an agent is selected).
const [dropdown1, setDropDown1] = useState(false);
const [dropdown2, setDropDown2] = useState(false);
const [dropdown3, setDropDown3] = useState(false);
// Selected agent: 'Select an Agent' / -1 mean "nothing selected yet".
// selectedAgentIndex actually holds an agent_id (compared against agent.agent_id below).
const [selectedAgent, setSelectedAgent] = useState('Select an Agent');
const [selectedAgentIndex, setSelectedAgentIndex] = useState(-1);
const [selectedAgentRun, setSelectedAgentRun] = useState([]);
const [activeRuns, setActiveRuns] = useState([]);
const [selectedAgentDetails, setSelectedAgentDetails] = useState(null);
const [toolsUsed, setToolsUsed] = useState([]);
const [showToolTip, setShowToolTip] = useState(false);
const [toolTipIndex, setToolTipIndex] = useState(-1);
// Default position/size of each dashboard card in grid units, keyed by card id `i`.
const initialLayout = [
{i: 'total_agents', x: 0, y: 0, w: 3, h: 1.5},
{i: 'total_tokens', x: 3, y: 0, w: 3, h: 1.5},
{i: 'total_runs', x: 6, y: 0, w: 3, h: 1.5},
{i: 'active_runs', x: 9, y: 0, w: 3, h: 2},
{i: 'most_used_tools', x: 9, y: 1, w: 3, h: 2},
{i: 'models_by_agents', x: 0, y: 1, w: 3, h: 2.5},
{i: 'runs_by_model', x: 3, y: 1, w: 3, h: 2.5},
{i: 'tokens_by_model', x: 6, y: 1, w: 3, h: 2.5},
{i: 'agent_details', x: 0, y: 2, w: 12, h: 2.5},
{i: 'total_tokens_consumed', x: 0, y: 3, w: 4, h: 2},
{i: 'total_calls_made', x: 4, y: 3, w: 4, h: 2},
{i: 'tokens_consumed_per_call', x: 8, y: 3, w: 4, h: 2},
];
// Restore the saved layout, falling back to the default.
// NOTE(review): localStorage is read during render. It does not exist during
// Next.js server-side rendering; confirm this component only renders client-side.
const storedLayout = localStorage.getItem('myLayoutKey');
const [layout, setLayout] = useState(storedLayout !== null ? JSON.parse(storedLayout) : initialLayout);
// Lets the persistence effect skip its first (mount) run.
const firstUpdate = useRef(true);
// Presumably wired to the grid's onLayoutChange; stores the new layout in state.
const onLayoutChange = (currentLayout) => {
setLayout(currentLayout);
};
// Reset to the default layout and persist it immediately.
const onClickLayoutChange = () => {
localStorage.setItem('myLayoutKey', JSON.stringify(initialLayout));
setLayout(initialLayout);
}
// Persist every layout change after mount.
useEffect(() => {
if (!firstUpdate.current) {
localStorage.setItem('myLayoutKey', JSON.stringify(layout));
} else {
firstUpdate.current = false;
}
}, [layout]);
// Mutates `data` in place: appends {name: model, value: 0} for every model in
// `modelList` that has no entry yet, so every listed model is represented.
const assignDefaultDataPerModel = (data, modelList) => {
const modelsInData = data.map(item => item.name);
modelList.forEach((model) => {
if (!modelsInData.includes(model)) {
data.push({name: model, value: 0});
}
});
};
// Fetch the four analytics resources in parallel on mount and then every 10 s.
// The interval is cleared on unmount; failures are only logged.
useEffect(() => {
const fetchData = async () => {
try {
const [metricsResponse, agentsResponse, activeRunsResponse, toolsUsageResponse] = await Promise.all([getMetrics(), getAllAgents(), getActiveRuns(), getToolsUsage()]);
// Models that always get an entry (zero if absent) in the per-model metrics.
const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat'];
assignDefaultDataPerModel(metricsResponse.data.agent_details.model_metrics, models);
assignDefaultDataPerModel(metricsResponse.data.tokens_details.model_metrics, models);
assignDefaultDataPerModel(metricsResponse.data.run_details.model_metrics, models);
setAgentDetails(metricsResponse.data.agent_details);
setTokenDetails(metricsResponse.data.tokens_details);
setRunDetails(metricsResponse.data.run_details);
setAllAgents(agentsResponse.data.agent_details);
setActiveRuns(activeRunsResponse.data);
setToolsUsed(toolsUsageResponse.data);
} catch (error) {
console.log(`Error in fetching data: ${error}`);
}
}
fetchData();
const interval = setInterval(fetchData, 10000);
return () => clearInterval(interval);
}, []);
// NOTE(review): logs toolsUsed on every change; looks like leftover debugging.
useEffect(() => {
console.log(toolsUsed)
}, [toolsUsed]);
// Select an agent: close all dropdowns, store its name/id and details, and load
// its runs. `index` is compared against agent.agent_id, i.e. it is an agent id.
// Memoised on allAgents so the details lookup sees the latest list.
const handleSelectedAgent = useCallback((index, name) => {
setDropDown1(false)
setDropDown2(false)
setDropDown3(false)
setSelectedAgent(name)
setSelectedAgentIndex(index)
const agentDetails = allAgents.find(agent => agent.agent_id === index);
setSelectedAgentDetails(agentDetails);
getAgentRuns(index).then((response) => {
const data = response.data;
setSelectedAgentRun(data);
}).catch((error) => console.error(`Error in fetching agent runs: ${error}`));
}, [allAgents]);
// Re-apply the current selection whenever the agent list changes.
// NOTE(review): on first render selectedAgentIndex is -1, so this requests
// getAgentRuns(-1); confirm the backend tolerates that, or guard it.
useEffect(() => handleSelectedAgent(selectedAgentIndex, selectedAgent), [allAgents]);
// Once agents are loaded and nothing is selected, select the last agent in the list.
useEffect(() => {
if (allAgents.length > 0 && selectedAgent === 'Select an Agent') {
const lastAgent = allAgents[allAgents.length - 1];
handleSelectedAgent(lastAgent.agent_id, lastAgent.name);
}
}, [allAgents, selectedAgent, handleSelectedAgent]);
// Record tooltip visibility and which item it belongs to.
const setToolTipState = (state, index) => {
setShowToolTip(state)
setToolTipIndex(index)
}
return (
Agent Performance Monitoring
{/**/}
Total Agents
{formatNumber(agentDetails.total_agents)}
Total tokens consumed
{formatNumber(tokenDetails.total_tokens)}
Total runs
{formatNumber(runDetails.total_runs)}
Number of Agents per model
{agentDetails.model_metrics && agentDetails.model_metrics.length > 0
? <>
Models
>
:
No Agents Found
}
Number of Runs per Model
{runDetails.model_metrics && runDetails.model_metrics.length > 0
? <>
Models
>
:
No Agents Found
}
Total Tokens consumed by models
{tokenDetails.model_metrics && tokenDetails.model_metrics.length > 0
? <>
Models
>
:
No Agents Found
}
Most used tools
{toolsUsed.length === 0 ?
No Used Tools Found
:
Tool
Agents
Calls
{toolsUsed.map((tool, index) => (
{tool.tool_name}
{tool.unique_agents}
{tool.total_usage}
))}
}
Agent Overview
{allAgents.length === 0 ?
{selectedAgent === 'Select an Agent' ? 'Please Select an Agent' :
No Runs found for {selectedAgent}}
openNewTab(-5, "new model", "Add_Model", false)} className="custom_select_option horizontal_container mxw_100 padding_12_14 gap_6 bt_white">
Add new custom model
The expiry date of the run
is {(new Date(`${expiryDate}Z`).toLocaleString()).substring(0, 10) == "Invalid Da" ? expiryDate : (new Date(`${expiryDate}Z`).toLocaleString()).substring(0, 10)}
The {templateData.provider} auth token is not added to your settings. In order to start using the model, you need to add the auth token to your settings. You can find the auth token in the {templateData.provider} dashboard.
}
{templateData.provider === 'Hugging Face' &&
In order to get the endpoint for this model, you will need to deploy it on your Replicate dashboard. Once you have deployed your model on Hugging Face, you will be able to access the endpoint through the Hugging Face dashboard. The endpoint is a URL that you can use to send requests to your model.
Your secret API keys are sensitive pieces of information that should be kept confidential. Do not share them with anyone, and do not expose them in any way. If your secret API keys are compromised, someone could use them to access your API and make unauthorized changes to your data. This secret key is only displayed once for security reasons. Please save it in a secure location where you can access it easily.
By continuing, you agree to Super AGI’s Terms of Service and Privacy Policy, and to receive important
updates.
:
{loadingText}
}
) : true}
);
}
================================================
FILE: gui/pages/api/DashboardService.js
================================================
import api from './apiConfig';
/*
 * Users, organisations, projects, agents, executions and resources.
 * Each helper issues one request through the shared `api` client and returns
 * whatever `api.<method>` returns (presumably the axios promise).
 */
export const getOrganisation = (userId) => api.get(`/organisations/get/user/${userId}`);

export const getGithubClientId = () => api.get(`/get/github_client_id`);

export const addUser = (userData) => api.post(`/users/add`, userData);

export const getProject = (organisationId) => api.get(`/projects/get/organisation/${organisationId}`);

export const getAgents = (projectId) => api.get(`/agents/get/project/${projectId}`);

export const getToolKit = () => api.get(`/toolkits/get/local/list`);

export const getTools = () => api.get(`/tools/list`);

export const getAgentDetails = (agentId, agentExecutionId) =>
  api.get(`/agent_executions_configs/details/agent_id/${agentId}/agent_execution_id/${agentExecutionId}`);

export const getAgentExecutions = (agentId) => api.get(`/agentexecutions/get/agent/${agentId}`);

export const getExecutionFeeds = (executionId) => api.get(`/agentexecutionfeeds/get/execution/${executionId}`);

export const getExecutionTasks = (executionId) => api.get(`/agentexecutionfeeds/get/tasks/${executionId}`);

// Scheduled agents go to a separate endpoint from immediately-created ones.
export const createAgent = (agentData, scheduledCreate) => {
  const endpoint = scheduledCreate ? `/agents/schedule` : `/agents/create`;
  return api.post(endpoint, agentData);
};

export const addAgentRun = (agentData) => api.post(`/agentexecutions/add_run`, agentData);

export const addTool = (toolData) => api.post(`/toolkits/get/local/install`, toolData);

export const updateExecution = (executionId, executionData) =>
  api.put(`/agentexecutions/update/${executionId}`, executionData);

export const editAgentTemplate = (agentTemplateId, agentTemplateData) =>
  api.put(`/agent_templates/update_agent_template/${agentTemplateId}`, agentTemplateData);

export const addExecution = (executionData) => api.post(`/agentexecutions/add`, executionData);

export const getResources = (agentId) => api.get(`/resources/get/all/${agentId}`);

export const getLastActiveAgent = (projectId) => api.get(`/agentexecutions/get/latest/agent/project/${projectId}`);

export const uploadFile = (agentId, formData) => api.post(`/resources/add/${agentId}`, formData);
/*
 * Validation, configuration, agent templates, permissions and third-party
 * credential helpers. Each returns the result of a single `api` call.
 */
export const validateAccessToken = () => api.get(`/validate-access-token`);

export const validateLLMApiKey = (model_source, model_api_key) =>
  api.post(`/validate-llm-api-key`, {model_source, model_api_key});

export const checkEnvironment = () => api.get(`/configs/get/env`);

export const getOrganisationConfig = (organisationId, key) =>
  api.get(`/configs/get/organisation/${organisationId}/key/${key}`);

export const updateOrganisationConfig = (organisationId, configData) =>
  api.post(`/configs/add/organisation/${organisationId}`, configData);

export const fetchAgentTemplateList = () => api.get('/agent_templates/list?template_source=marketplace');

export const fetchAgentTemplateDetails = (templateId) => api.get(`/agent_templates/get/${templateId}`);

export const getToolConfig = (toolKitName) => api.get(`/tool_configs/get/toolkit/${toolKitName}`);

export const updateToolConfig = (toolKitName, configData) => api.post(`/tool_configs/add/${toolKitName}`, configData);

export const fetchAgentTemplateListLocal = () => api.get('/agent_templates/list?template_source=local');

export const saveAgentAsTemplate = (agentId, executionId) =>
  api.post(`/agent_templates/save_agent_as_template/agent_id/${agentId}/agent_execution_id/${executionId}`);

export const fetchAgentTemplateConfig = (templateId) =>
  api.get(`/agent_templates/get/${templateId}?template_source=marketplace`);

export const installAgentTemplate = (templateId) =>
  api.post(`/agent_templates/download?agent_template_id=${templateId}`);

export const fetchAgentTemplateConfigLocal = (templateId) =>
  api.get(`/agent_templates/agent_config?agent_template_id=${templateId}`);

export const updatePermissions = (permissionId, data) =>
  api.put(`/agentexecutionpermissions/update/status/${permissionId}`, data);

export const deleteAgent = (agentId) => api.put(`/agents/delete/${agentId}`);

export const authenticateGoogleCred = (toolKitId) => api.get(`/google/get_google_creds/toolkit_id/${toolKitId}`);

export const authenticateTwitterCred = (toolKitId) => api.get(`/twitter/get_twitter_creds/toolkit_id/${toolKitId}`);

// NOTE(review): the credentials are interpolated into the URL path, which
// servers and proxies commonly log. Consider sending them in the request body
// instead (this needs a matching backend change).
export const sendTwitterCreds = (twitter_creds) => api.post(`/twitter/send_twitter_creds/${twitter_creds}`);

export const sendGoogleCreds = (google_creds, toolkit_id) =>
  api.post(`/google/send_google_creds/toolkit_id/${toolkit_id}`, google_creds);
/*
 * Marketplace toolkits/knowledge listings, execution details and agent
 * scheduling. Each returns the result of a single `api` call.
 */
export const fetchToolTemplateList = () => api.get(`/toolkits/get/list?page=0`);

export const fetchKnowledgeTemplateList = () => api.get(`/knowledges/get/list?page=0`);

export const fetchToolTemplateOverview = (toolTemplateName) =>
  api.get(`/toolkits/marketplace/readme/${toolTemplateName}`);

export const updateMarketplaceToolTemplate = (templateName) => api.put(`/toolkits/update/${templateName}`);

export const installToolkitTemplate = (templateName) => api.get(`/toolkits/get/install/${templateName}`);

export const checkToolkitUpdate = (templateName) => api.get(`/toolkits/check_update/${templateName}`);

// Note the argument order (executionId, agentId) is the reverse of the path order.
export const getExecutionDetails = (executionId, agentId) =>
  api.get(`/agent_executions_configs/details/agent/${agentId}/agent_execution/${executionId}`);

export const stopSchedule = (agentId) => api.post(`/agents/stop/schedule?agent_id=${agentId}`);

export const createAndScheduleRun = (requestData) => api.post(`/agentexecutions/schedule`, requestData);

export const agentScheduleComponent = (agentId) => api.get(`/agents/get/schedule_data/${agentId}`);

export const updateSchedule = (requestData) => api.put(`/agents/edit/schedule`, requestData);

// Same request as agentScheduleComponent; kept as a separate export for existing callers.
export const getDateTime = (agentId) => agentScheduleComponent(agentId);
/*
 * Analytics (APM dashboard) and organisation-level model/workflow endpoints.
 *
 * Every path starts with "/", like every other endpoint in this module. When
 * the shared axios instance has a baseURL, axios joins both forms identically.
 * Without one, a slash-less path would resolve relative to the current page URL
 * instead of the site root. (apiConfig is not visible here; verify its baseURL.)
 */
export const getMetrics = () => {
  return api.get(`/analytics/metrics`);
};

export const getAllAgents = () => {
  return api.get(`/analytics/agents/all`);
};

export const getAgentRuns = (agent_id) => {
  return api.get(`/analytics/agents/${agent_id}`);
};

export const getActiveRuns = () => {
  return api.get(`/analytics/runs/active`);
};

export const getToolsUsage = () => {
  return api.get(`/analytics/tools/used`);
};

export const modelInfo = (model) => {
  return api.get(`/analytics/model_details/${model}`);
};

export const getLlmModels = () => {
  return api.get(`/organisations/llm_models`);
};

export const getAgentWorkflows = () => {
  return api.get(`/organisations/agent_workflows`);
};
// ---- Vector databases ------------------------------------------------------

/** Vector DB list (`/vector_dbs/get/list`). */
export const fetchVectorDBList = () => api.get(`/vector_dbs/get/list`);

/** Vector DBs of the current user. */
export const getVectorDatabases = () => api.get(`/vector_dbs/user/list`);

/** Details of one vector DB. */
export const getVectorDBDetails = (vectorDBId) => api.get(`/vector_dbs/db/details/${vectorDBId}`);

/** Delete a vector DB (the backend exposes this as a POST). */
export const deleteVectorDB = (vectorDBId) => api.post(`/vector_dbs/delete/${vectorDBId}`);

/** Replace the indices of a vector DB with `newIndices`. */
export const updateVectorDB = (vectorDBId, newIndices) => api.put(`/vector_dbs/update/vector_db/${vectorDBId}`, newIndices);

/** Connect a Pinecone database. */
export const connectPinecone = (pineconeData) => api.post(`/vector_dbs/connect/pinecone`, pineconeData);

/** Connect a Qdrant database. */
export const connectQdrant = (qdrantData) => api.post(`/vector_dbs/connect/qdrant`, qdrantData);

/** Connect a Weaviate database. */
export const connectWeaviate = (weaviateData) => api.post(`/vector_dbs/connect/weaviate`, weaviateData);

// ---- Knowledge ---------------------------------------------------------------

/** Knowledge entries of the current user. */
export const getKnowledge = () => api.get(`/knowledges/user/list`);

/** Details of one of the user's knowledge entries. */
export const getKnowledgeDetails = (knowledgeId) => api.get(`/knowledges/user/get/details/${knowledgeId}`);

/** Delete a user-created knowledge entry, by id. */
export const deleteCustomKnowledge = (knowledgeId) => api.post(`/knowledges/delete/${knowledgeId}`);

/** Uninstall a marketplace knowledge entry, by name. */
export const deleteMarketplaceKnowledge = (knowledgeName) => api.post(`/knowledges/uninstall/${knowledgeName}`);

/** Create or update a knowledge entry. */
export const addUpdateKnowledge = (knowledgeData) => api.post(`/knowledges/add_or_update/data`, knowledgeData);

/** Vector DB indices the user may attach knowledge to. */
export const getValidIndices = () => api.get(`/vector_db_indices/user/valid_indices`);

/** Indices valid for a given marketplace knowledge entry. */
export const getValidMarketplaceIndices = (knowledgeName) => api.get(`/vector_db_indices/marketplace/valid_indices/${knowledgeName}`);

/** Marketplace details of a knowledge entry. */
export const fetchKnowledgeTemplateOverview = (knowledgeName) => api.get(`/knowledges/marketplace/get/details/${knowledgeName}`);

/** Install a marketplace knowledge entry into an index (the backend exposes this as a GET). */
export const installKnowledgeTemplate = (knowledgeName, indexId) => api.get(`/knowledges/install/${knowledgeName}/index/${indexId}`);
// ---- API keys, webhooks and publishing ---------------------------------------

/** Create an API key from the given payload. */
export const createApiKey = (apiName) => api.post(`/api-keys`, apiName);

/** List API keys. */
export const getApiKeys = () => api.get(`/api-keys`);

/** Edit an API key. */
export const editApiKey = (apiDetails) => api.put(`/api-keys`, apiDetails);

/** Delete an API key by id. */
export const deleteApiKey = (apiId) => api.delete(`/api-keys/${apiId}`);

/** Register a webhook. */
export const saveWebhook = (webhook) => api.post(`/webhook/add`, webhook);

/** Fetch the configured webhook. */
export const getWebhook = () => api.get(`/webhook/get`);

/** Edit an existing webhook (sent as a POST). */
export const editWebhook = (webhook_id, webhookData) => api.post(`/webhook/edit/${webhook_id}`, webhookData);

/** Publish the template of an agent execution to the marketplace. */
export const publishToMarketplace = (executionId) => api.post(`/agent_templates/publish_template/agent_execution_id/${executionId}`);
// ---- Model configuration -----------------------------------------------------

/** Store the API key of a model provider. */
export const storeApiKey = (model_provider, model_api_key) => {
  return api.post(`/models_controller/store_api_keys`, {model_provider, model_api_key});
};

/** API keys of all configured providers. */
export const fetchApiKeys = () => {
  return api.get(`/models_controller/get_api_keys`);
};

/**
 * API key of one model provider.
 * The provider name goes through axios `params` (as verifyEndPoint does) so names with
 * spaces or reserved characters (e.g. "Hugging Face") are URL-encoded correctly.
 */
export const fetchApiKey = (model_provider) => {
  return api.get(`/models_controller/get_api_key`, {params: {model_provider}});
};

/** Verify a model end point with the given key and provider. */
export const verifyEndPoint = (model_api_key, end_point, model_provider) => {
  return api.get(`/models_controller/verify_end_point`, {
    params: {model_api_key, end_point, model_provider}
  });
};

/** Store a model definition. */
export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => {
  return api.post(`/models_controller/store_model`, {model_name, description, end_point, model_provider_id, token_limit, type, version, context_length});
};

/** Trigger the backend's local-LLM test endpoint. */
export const testModel = () => {
  return api.get(`/models_controller/test_local_llm`);
};

/** List stored models. */
export const fetchModels = () => {
  return api.get(`/models_controller/fetch_models`);
};

/** One stored model, by id. */
export const fetchModel = (model_id) => {
  return api.get(`/models_controller/fetch_model/${model_id}`);
};

/** Data for the given model. */
export const fetchModelData = (model) => {
  return api.post(`/models_controller/fetch_model_data`, {model: model});
};

/** Marketplace model list. */
export const fetchMarketPlaceModel = () => {
  return api.get(`/models_controller/get/list`);
};
// ---- Tool / knowledge analytics, template publishing, signup source ----------

/** Usage metrics of a tool, by name. */
export const getToolMetrics = (toolName) => api.get(`analytics/tools/${toolName}/usage`);

/** Usage logs of a tool, by name. */
export const getToolLogs = (toolName) => api.get(`analytics/tools/${toolName}/logs`);

/** Publish an agent template described by `agentData`. */
export const publishTemplateToMarketplace = (agentData) => api.post(`/agent_templates/publish_template`, agentData);

/** Usage metrics of a knowledge entry, by name. */
export const getKnowledgeMetrics = (knowledgeName) => api.get(`analytics/knowledge/${knowledgeName}/usage`);

/** Usage logs of a knowledge entry, by name. */
export const getKnowledgeLogs = (knowledgeName) => api.get(`analytics/knowledge/${knowledgeName}/logs`);

/** Record the source of a user's first login (no request body). */
export const getFirstSignup = (source) => api.post(`/users/first_login_source/${source}`);
================================================
FILE: gui/pages/api/apiConfig.js
================================================
import axios from 'axios';
import Cookies from "js-cookie";
// Configuration captured from the environment once, when this module loads.
// NOTE(review): Next.js inlines only NEXT_PUBLIC_* variables into the browser bundle
// by default; the others are presumably exposed through next.config.js — confirm.
const GITHUB_CLIENT_ID = process.env.GITHUB_CLIENT_ID;
const API_BASE_URL = process.env.NEXT_PUBLIC_API_BASE_URL || 'http://localhost:8001';
const GOOGLE_ANALYTICS_MEASUREMENT_ID = process.env.GOOGLE_ANALYTICS_MEASUREMENT_ID;
const GOOGLE_ANALYTICS_API_SECRET = process.env.GOOGLE_ANALYTICS_API_SECRET;
const MIXPANEL_AUTH_ID = process.env.NEXT_PUBLIC_MIXPANEL_AUTH_ID;

/** Backend base URL; falls back to http://localhost:8001. */
export const baseUrl = () => API_BASE_URL;

/** GitHub OAuth client id (GITHUB_CLIENT_ID). */
export const githubClientId = () => GITHUB_CLIENT_ID;

/** Google Analytics measurement id (GOOGLE_ANALYTICS_MEASUREMENT_ID). */
export const analyticsMeasurementId = () => GOOGLE_ANALYTICS_MEASUREMENT_ID;

/** Google Analytics API secret (GOOGLE_ANALYTICS_API_SECRET). */
export const analyticsApiSecret = () => GOOGLE_ANALYTICS_API_SECRET;

/** Mixpanel project token (NEXT_PUBLIC_MIXPANEL_AUTH_ID). */
export const mixpanelId = () => MIXPANEL_AUTH_ID;
/**
 * Shared axios client for the SuperAGI backend.
 * Requests default to a JSON content type and, in the browser, carry the
 * "accessToken" cookie as a bearer token.
 */
const api = axios.create({
  baseURL: API_BASE_URL,
  headers: {common: {'Content-Type': 'application/json'}},
});

/** Request interceptor: attach the access token when one is available. */
const attachAccessToken = (config) => {
  // Cookies can only be read in the browser; leave server-side requests untouched.
  if (typeof window === 'undefined') {
    return config;
  }
  const accessToken = Cookies.get("accessToken");
  if (accessToken) {
    config.headers['Authorization'] = `Bearer ${accessToken}`;
  }
  return config;
};

api.interceptors.request.use(attachAccessToken);

export default api;
================================================
FILE: gui/utils/eventBus.js
================================================
import mitt from 'mitt';
// Application-wide pub/sub bus: one mitt emitter whose on/off/emit are re-exported.
const emitter = mitt();
const {on, off, emit} = emitter;

export const EventBus = {on, off, emit};
================================================
FILE: gui/utils/utils.js
================================================
import {formatDistanceToNow, format, addMinutes} from 'date-fns';
import {utcToZonedTime} from 'date-fns-tz';
import {baseUrl, analyticsMeasurementId, analyticsApiSecret, mixpanelId} from "@/pages/api/apiConfig";
import {EventBus} from "@/utils/eventBus";
import JSZip from "jszip";
import moment from 'moment';
import mixpanel from 'mixpanel-browser'
import Cookies from "js-cookie";
// Icon path for each built-in toolkit, keyed by the toolkit's display name.
// Names missing from this map fall back to '/images/custom_tool.svg' in returnToolkitIcon.
const toolkitData = {
  'Jira Toolkit': '/images/jira_icon.svg',
  'Email Toolkit': '/images/gmail_icon.svg',
  'Google Calendar Toolkit': '/images/google_calender_icon.svg',
  'GitHub Toolkit': '/images/github_icon.svg',
  'Google Search Toolkit': '/images/google_search_icon.svg',
  'Searx Toolkit': '/images/searx_icon.svg',
  'Slack Toolkit': '/images/slack_icon.svg',
  'Web Scraper Toolkit': '/images/webscraper_icon.svg',
  // Misspelled alias of the entry above; presumably kept so toolkits registered
  // under this spelling still get an icon — TODO confirm before removing.
  'Web Scrapper Toolkit': '/images/webscraper_icon.svg',
  'Twitter Toolkit': '/images/twitter_icon.svg',
  'Google SERP Toolkit': '/images/google_serp_icon.svg',
  'File Toolkit': '/images/filemanager_icon.svg',
  // The toolkits below have no dedicated icon and reuse the SuperAGI logo.
  'CodingToolkit': '/images/superagi_logo.png',
  'Thinking Toolkit': '/images/superagi_logo.png',
  'Image Generation Toolkit': '/images/superagi_logo.png',
  'DuckDuckGo Search Toolkit': '/images/duckduckgo_icon.png',
  'Instagram Toolkit': '/images/instagram.png',
  // (sic) "knowledeg" — presumably matches the asset's actual file name; verify before renaming.
  'Knowledge Search Toolkit': '/images/knowledeg_logo.png',
  'Notion Toolkit': '/images/notion_logo.png',
  'ApolloToolkit': '/images/apollo_logo.png',
  'Google Analytics Toolkit': '/images/google_analytics_logo.png'
};
/** IANA time-zone name of the current runtime, e.g. "Europe/Berlin". */
export const getUserTimezone = () => Intl.DateTimeFormat().resolvedOptions().timeZone;

/**
 * Format a date/time as a UTC "YYYY-MM-DD HH:mm:ss" string.
 * @returns {string|null} null for any falsy input.
 */
export const convertToGMT = (dateTime) => (
  dateTime ? moment.utc(dateTime).format('YYYY-MM-DD HH:mm:ss') : null
);
/**
 * Render a pre-computed time difference as "N <unit> ago" for its largest non-zero unit.
 *
 * @param {Object} timeDifference - numeric `years`, `months`, `days`, `hours`, `minutes`
 *   fields; a missing or null field counts as zero.
 * @returns {string} e.g. "3 days ago" or "1 hour ago"; "Just now" when every unit is
 *   zero/absent or no object is given.
 */
export const formatTimeDifference = (timeDifference) => {
  const units = ['years', 'months', 'days', 'hours', 'minutes'];
  const singularUnits = ['year', 'month', 'day', 'hour', 'minute'];
  if (!timeDifference) {
    return 'Just now';
  }
  for (let i = 0; i < units.length; i++) {
    const value = timeDifference[units[i]];
    // Skip absent units as well as zero ones; an undefined field used to be
    // reported as "undefined years ago".
    if (value === undefined || value === null || value === 0) {
      continue;
    }
    return `${value} ${value === 1 ? singularUnits[i] : units[i]} ago`;
  }
  return 'Just now';
};
/**
 * Abbreviate a number with a k/M/B/T suffix, keeping at most one decimal.
 *
 * Values below 1000 (including fractions) get no suffix, values of 1e15 and above
 * stay in "T", and negative numbers keep their sign. Previously these cases indexed
 * outside the suffix table and produced strings such as "500undefined".
 *
 * @param {number} number - value to format; null/undefined/0 give '0'.
 * @returns {string} e.g. '42', '1.5k', '2M', '-2k'.
 */
export const formatNumber = (number) => {
  if (number === null || number === undefined || number === 0) {
    return '0';
  }
  const suffixes = ['', 'k', 'M', 'B', 'T'];
  const sign = number < 0 ? '-' : '';
  const absolute = Math.abs(number);
  // Clamp to the available suffixes: fractions give a negative magnitude and
  // values >= 1e15 exceed the table.
  const magnitude = Math.min(
    Math.max(Math.floor(Math.log10(absolute) / 3), 0),
    suffixes.length - 1
  );
  const scaledNumber = absolute / Math.pow(10, magnitude * 3);
  const suffix = suffixes[magnitude];
  if (scaledNumber % 1 === 0) {
    return sign + scaledNumber.toFixed(0) + suffix;
  }
  return sign + scaledNumber.toFixed(1) + suffix;
};
/**
 * Describe how long ago a backend timestamp was, e.g. "5 min ago" or "2 hrs ago".
 *
 * @param {string} lastExecutionTime - timestamp without a zone designator; 'Z' is
 *   appended so it is parsed as UTC.
 * @returns {string} abbreviated relative time, or 'Invalid Time' when parsing fails.
 */
export const formatTime = (lastExecutionTime) => {
  try {
    const parsedTime = new Date(lastExecutionTime + 'Z'); // append 'Z' to indicate UTC
    if (isNaN(parsedTime.getTime())) {
      throw new Error('Invalid time value');
    }
    // A relative distance depends only on the absolute instant, so the parsed Date is
    // compared with "now" directly. Passing it through utcToZonedTime(…, 'Asia/Kolkata')
    // first shifted the instant by the IST-to-viewer offset, so viewers outside IST saw
    // results hours off (e.g. "in about 5 hours" for a run that had just finished).
    return formatDistanceToNow(parsedTime, {
      addSuffix: true,
      includeSeconds: true
    }).replace(/about\s/, '')
      .replace(/minutes?/, 'min')
      .replace(/hours?/, 'hrs')
      .replace(/days?/, 'day')
      .replace(/weeks?/, 'week');
  } catch (error) {
    console.error('Error formatting time:', error);
    return 'Invalid Time';
  }
};
/**
 * Format a byte count with a binary (1024-based) unit, e.g. "1.5 KB".
 *
 * @param {number} bytes - size in bytes; 0, null, undefined and NaN give '0 Bytes'.
 * @param {number} [decimals=2] - maximum fraction digits; negative values act as 0.
 * @returns {string} formatted size.
 */
export const formatBytes = (bytes, decimals = 2) => {
  // Previously only 0 was handled; null/undefined/NaN produced "NaN undefined".
  if (!bytes) {
    return '0 Bytes';
  }
  const k = 1024;
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'];
  // toFixed() throws a RangeError for negative digit counts.
  const fractionDigits = decimals < 0 ? 0 : decimals;
  // Use the magnitude (so negative sizes work) and clamp to the unit table, so
  // sub-byte values stay in "Bytes" instead of indexing sizes[-1].
  const i = Math.min(
    Math.max(Math.floor(Math.log(Math.abs(bytes)) / Math.log(k)), 0),
    sizes.length - 1
  );
  const formattedValue = parseFloat((bytes / Math.pow(k, i)).toFixed(fractionDigits));
  return `${formattedValue} ${sizes[i]}`;
};
/**
 * Fetch a resource file (`/resources/get/{fileId}`) from the backend.
 *
 * When localStorage 'applicationEnvironment' is 'PROD', the request carries the
 * "accessToken" cookie as a bearer token; otherwise it is sent without it.
 *
 * @param {number|string} fileId - resource id.
 * @param {string|null} [fileName=null] - when given, the file is handed to the browser
 *   as a download: in PROD through an object URL saved as `fileName`, otherwise by
 *   opening the URL in a new tab.
 * @returns {Promise<Blob|undefined>|undefined} without `fileName`, a promise for the
 *   file's Blob that resolves to undefined if the request failed (the error is logged).
 */
export const downloadFile = (fileId, fileName = null) => {
  const authToken = Cookies.get("accessToken");
  const url = `${baseUrl()}/resources/get/${fileId}`;
  const env = localStorage.getItem('applicationEnvironment');

  // fetch() only rejects on network errors. Treat HTTP error statuses as failures as
  // well, so an error response body is never returned or saved as the file.
  const fetchBlob = (options) => fetch(url, options).then((response) => {
    if (!response.ok) {
      throw new Error(`Request failed with status ${response.status}`);
    }
    return response.blob();
  });

  if (env === 'PROD') {
    return fetchBlob({headers: {Authorization: `Bearer ${authToken}`}})
      .then((blob) => {
        if (!fileName) {
          return blob;
        }
        const fileUrl = window.URL.createObjectURL(blob);
        const anchorElement = document.createElement('a');
        anchorElement.href = fileUrl;
        anchorElement.download = fileName;
        anchorElement.click();
        window.URL.revokeObjectURL(fileUrl);
      })
      .catch((error) => {
        console.error('Error downloading file:', error);
      });
  }

  if (fileName) {
    window.open(url, '_blank');
    return;
  }
  return fetchBlob()
    .catch((error) => {
      console.error('Error downloading file:', error);
    });
};
/**
 * Download several resource files and save them together as one zip archive.
 *
 * Each file is fetched with downloadFile(file.id). Repeated names are disambiguated
 * as "name (1).ext", "name (2).ext", ... Files that fail to download are logged and
 * left out of the archive. The archive is saved as
 * "<run_name>_<YYYY-MM-DD>_<HH-mm-ss>.zip".
 *
 * @param {Array<{id: *, name: string, type: string}>} files - resources to include.
 * @param {string} run_name - prefix of the archive file name.
 * @returns {Promise<void>} settles after the archive download has been triggered
 *   (or after a zipping error has been logged).
 */
export const downloadAllFiles = (files, run_name) => {
  const zip = new JSZip();
  const fileNamesCount = {};
  const pad = (value) => String(value).padStart(2, '0');

  // Return a name that is unique within this archive.
  const uniqueName = (name) => {
    fileNamesCount[name] = (fileNamesCount[name] || 0) + 1;
    const duplicateIndex = fileNamesCount[name] - 1;
    if (duplicateIndex === 0) {
      return name;
    }
    const extensionIndex = name.lastIndexOf('.');
    // No extension (or a dot-file such as ".env"): append the counter. Splitting at
    // index -1 used to produce names like " (1).report".
    if (extensionIndex <= 0) {
      return `${name} (${duplicateIndex})`;
    }
    return `${name.substring(0, extensionIndex)} (${duplicateIndex})${name.substring(extensionIndex)}`;
  };

  const promises = files.map((file) => {
    const entryName = uniqueName(file.name);
    return downloadFile(file.id)
      .then((blob) => {
        // downloadFile logs failures and resolves to undefined; zipping that value
        // would create an entry containing the text "undefined".
        if (!blob) {
          console.error("Error downloading file:", file.name);
          return;
        }
        zip.file(entryName, new Blob([blob], {type: file.type}));
      })
      .catch((error) => {
        console.error("Error downloading file:", error);
      });
  });

  return Promise.all(promises)
    .then(() => zip.generateAsync({type: "blob"}))
    .then((content) => {
      const now = new Date();
      const timestamp = `${now.getFullYear()}-${pad(now.getMonth() + 1)}-${pad(now.getDate())}_${pad(now.getHours())}-${pad(now.getMinutes())}-${pad(now.getSeconds())}`;
      const objectUrl = URL.createObjectURL(content);
      const downloadLink = document.createElement("a");
      downloadLink.href = objectUrl;
      downloadLink.download = `${run_name}_${timestamp}.zip`;
      downloadLink.click();
      // Release the archive blob once the click has been dispatched.
      setTimeout(() => URL.revokeObjectURL(objectUrl), 0);
    })
    .catch((error) => {
      console.error("Error generating zip:", error);
    });
};
/**
 * Drop the query string and hash from the address bar without reloading, e.g. to
 * remove a token passed in the URL. No-op outside the browser.
 */
export const refreshUrl = () => {
  if (typeof window !== 'undefined') {
    const location = window.location;
    window.history.replaceState({}, document.title, `${location.origin}${location.pathname}`);
  }
};
/**
 * Animate trailing dots ("", ".", "..", "...", then repeat) after a loading label.
 *
 * @param {string} loadingText - base label.
 * @param {Function} setLoadingText - setter called with each animated label.
 * @param {number} timer - interval between frames, in milliseconds.
 * @returns {Function} cleanup that stops the animation.
 */
export const loadingTextEffect = (loadingText, setLoadingText, timer) => {
  const baseText = loadingText;
  let dotCount = 0;
  const intervalId = setInterval(() => {
    dotCount = dotCount < 3 ? dotCount + 1 : 0;
    setLoadingText(`${baseText}${'.'.repeat(dotCount)}`);
  }, timer);
  return () => clearInterval(intervalId);
};
/**
 * Ask the tab host (via EventBus 'openNewTab') to open a tab.
 * With `hasInternalId`, a fresh internal id is allocated; otherwise internalId is 0.
 */
export const openNewTab = (id, name, contentType, hasInternalId = false) => {
  const internalId = hasInternalId ? createInternalId() : 0;
  EventBus.emit('openNewTab', {element: {id, name, contentType, internalId}});
};

/** Ask the tab host (via EventBus 'removeTab') to close the described tab. */
export const removeTab = (id, name, contentType, internalId) => {
  EventBus.emit('removeTab', {element: {id, name, contentType, internalId}});
};
/** Update React state with `value`, then persist it under `key` in localStorage. */
export const setLocalStorageValue = (key, value, stateFunction) => {
  stateFunction(value);
  localStorage.setItem(key, value);
};

/** Like setLocalStorageValue, but persists `value` as a JSON string. */
export const setLocalStorageArray = (key, value, stateFunction) => {
  stateFunction(value);
  localStorage.setItem(key, JSON.stringify(value));
};
/**
 * Read the list of allocated tab internal ids from localStorage ('agi_internal_ids').
 * Returns [] when nothing is stored or the stored value is not a JSON array
 * (corrupted or hand-edited storage), instead of throwing.
 */
const getInternalIds = () => {
  const internal_ids = localStorage.getItem("agi_internal_ids");
  if (!internal_ids) {
    return [];
  }
  try {
    const parsed = JSON.parse(internal_ids);
    return Array.isArray(parsed) ? parsed : [];
  } catch (error) {
    console.error('Invalid agi_internal_ids in localStorage:', error);
    return [];
  }
};

/**
 * Release `internalId`: drop it from 'agi_internal_ids' and remove every localStorage
 * entry named `<prefix><internalId>` for the given prefixes. Does nothing when the id
 * is not currently registered.
 */
const releaseInternalId = (internalId, keyPrefixes) => {
  const idsArray = getInternalIds();
  const internalIdIndex = idsArray.indexOf(internalId);
  if (internalIdIndex === -1) {
    return;
  }
  idsArray.splice(internalIdIndex, 1);
  localStorage.setItem('agi_internal_ids', JSON.stringify(idsArray));
  keyPrefixes.forEach((prefix) => localStorage.removeItem(prefix + String(internalId)));
};

// Per-tab keys written by the agent creation form.
const AGENT_TAB_KEY_PREFIXES = [
  "agent_create_click_", "agent_name_", "agent_description_", "agent_goals_",
  "agent_instructions_", "agent_constraints_", "agent_model_", "agent_type_",
  "tool_names_", "tool_ids_", "agent_rolling_window_", "agent_database_",
  "agent_permission_", "agent_exit_criterion_", "agent_iterations_", "agent_step_time_",
  "advanced_options_", "has_LTM_", "has_resource_", "agent_files_",
  "agent_start_time_", "agent_expiry_date_", "agent_expiry_type_", "agent_expiry_runs_",
  "agent_time_unit_", "agent_time_value_", "agent_is_recurring_", "is_agent_template_",
  "agent_template_id_", "agent_knowledge_", "agent_knowledge_id_", "is_editing_agent_",
];
const ADD_TOOLKIT_TAB_KEY_PREFIXES = ['tool_github_'];
const TOOLKITS_TAB_KEY_PREFIXES = ['toolkit_tab_', 'api_configs_'];
const KNOWLEDGE_TAB_KEY_PREFIXES = ['knowledge_name_', 'knowledge_description_', 'knowledge_index_'];
// NOTE: 'pincone_api_' (sic) is the key spelling in use; presumably the Add Database
// form writes the same spelling — confirm before changing it.
const ADD_DATABASE_TAB_KEY_PREFIXES = [
  'add_database_tab_', 'selected_db_', 'db_name_', 'db_collections_', 'pincone_api_',
  'pinecone_env_', 'qdrant_api_', 'qdrant_url_', 'qdrant_port_',
];
const DATABASE_TAB_KEY_PREFIXES = ['db_details_collections_'];

const removeAgentInternalId = (internalId) => releaseInternalId(internalId, AGENT_TAB_KEY_PREFIXES);
const removeAddToolkitInternalId = (internalId) => releaseInternalId(internalId, ADD_TOOLKIT_TAB_KEY_PREFIXES);
const removeToolkitsInternalId = (internalId) => releaseInternalId(internalId, TOOLKITS_TAB_KEY_PREFIXES);
const removeKnowledgeInternalId = (internalId) => releaseInternalId(internalId, KNOWLEDGE_TAB_KEY_PREFIXES);
const removeAddDatabaseInternalId = (internalId) => releaseInternalId(internalId, ADD_DATABASE_TAB_KEY_PREFIXES);
const removeDatabaseInternalId = (internalId) => releaseInternalId(internalId, DATABASE_TAB_KEY_PREFIXES);
/**
 * Clear the localStorage state kept for a closed tab of the given content type.
 * Unknown content types are ignored.
 */
export const resetLocalStorage = (contentType, internalId) => {
  const clearKeys = (keys) => keys.forEach((key) => localStorage.removeItem(key));
  const resetters = new Map([
    ['Create_Agent', () => removeAgentInternalId(internalId)],
    ['Add_Toolkit', () => removeAddToolkitInternalId(internalId)],
    ['Marketplace', () => clearKeys(['marketplace_tab', 'market_item_clicked', 'market_detail_type', 'market_item'])],
    ['Toolkits', () => removeToolkitsInternalId(internalId)],
    ['Knowledge', () => removeKnowledgeInternalId(internalId)],
    ['Add_Knowledge', () => removeKnowledgeInternalId(internalId)],
    ['Add_Database', () => removeAddDatabaseInternalId(internalId)],
    ['Database', () => removeDatabaseInternalId(internalId)],
    ['Settings', () => clearKeys(['settings_tab'])],
  ]);
  const reset = resetters.get(contentType);
  if (reset) {
    reset();
  }
};
/**
 * Allocate the smallest positive integer not yet in 'agi_internal_ids' and record it.
 * Outside the browser nothing is stored and 1 is returned.
 */
export const createInternalId = () => {
  if (typeof window === 'undefined') {
    return 1;
  }
  const idsArray = getInternalIds();
  let candidate = 1;
  while (idsArray.includes(candidate)) {
    candidate += 1;
  }
  idsArray.push(candidate);
  localStorage.setItem('agi_internal_ids', JSON.stringify(idsArray));
  return candidate;
};
/** Icon path for a toolkit name, or the generic custom-tool icon when unknown. */
export const returnToolkitIcon = (toolkitName) => toolkitData[toolkitName] || '/images/custom_tool.svg';
/**
 * Icon path for a resource file, chosen by its MIME type.
 *
 * @param {{type?: string}} file - resource; a missing file or type yields the default
 *   icon (this used to throw on `.includes` of undefined).
 * @returns {string} icon path.
 */
export const returnResourceIcon = (file) => {
  const fileType = file && file.type ? String(file.type) : '';
  if (fileType.includes('image')) {
    return '/images/img_file.svg';
  }
  if (fileType === 'application/pdf') {
    return '/images/pdf_file.svg';
  }
  if (fileType === 'application/txt' || fileType === 'text/plain') {
    return '/images/txt_file.svg';
  }
  return '/images/default_file.svg';
};
// Icon path per vector database type.
const DATABASE_TYPE_ICONS = {
  'Pinecone': '/images/pinecone.svg',
  'Qdrant': '/images/qdrant.svg',
  'Weaviate': '/images/weaviate.svg'
};

/** Icon path for a vector DB type; undefined for unknown types. */
export const returnDatabaseIcon = (database) => DATABASE_TYPE_ICONS[database];
/** Turn "SOME_snake_CASE" into "Some Snake Case"; falsy input gives ''. */
export const convertToTitleCase = (str) => {
  if (!str) {
    return '';
  }
  return str
    .toLowerCase()
    .split('_')
    .map((word) => `${word.charAt(0).toUpperCase()}${word.slice(1)}`)
    .join(' ');
};
/**
 * Stop an event from bubbling to parent handlers.
 * NOTE: despite its name this only calls stopPropagation(); the browser's default
 * action is NOT prevented.
 */
export const preventDefault = (event) => {
  event.stopPropagation();
};
/** Names of toolkits excluded from the list; a new array on every call. */
export const excludedToolkits = () => ["Thinking Toolkit", "Human Input Toolkit", "Resource Toolkit"];
/** Format a date-like value as "D Mon YYYY" (e.g. "15 Aug 2023") in local time. */
export const getFormattedDate = (data) => {
  const monthAbbreviations = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"];
  const date = new Date(data);
  const monthName = monthAbbreviations[date.getMonth()];
  return `${date.getDate()} ${monthName} ${date.getFullYear()}`;
};
// Logo per model provider.
const MODEL_PROVIDER_ICONS = {
  'Hugging Face': '/images/huggingface_logo.svg',
  'Google Palm': '/images/google_palm_logo.svg',
  'Replicate': '/images/replicate_logo.svg',
  'OpenAI': '/images/openai_logo.svg',
};

// Page where a user obtains an API key/token for each model provider.
const MODEL_PROVIDER_KEY_PAGES = {
  'Replicate': 'https://replicate.com/account/api-tokens',
  'Hugging Face': 'https://huggingface.co/settings/tokens',
  'OpenAI': 'https://platform.openai.com/account/api-keys',
  'Google Palm': 'https://developers.generativeai.google/products/palm',
};

/** Logo path of a model provider; undefined when unknown. */
export const modelIcon = (model) => MODEL_PROVIDER_ICONS[model];

/** Link to a provider's API-key page; undefined when unknown. */
export const modelGetAuth = (modelProvider) => MODEL_PROVIDER_KEY_PAGES[modelProvider];
/**
 * Format a date/time string as "d MMM yyyy HH:mm" after adding a fixed 5h30m.
 * NOTE(review): the hard-coded +05:30 looks like an IST display offset applied on top
 * of the local-time parse; it ignores the viewer's zone — confirm this is intended.
 */
export const formatDateTime = (dateTimeString) => {
  const istOffsetMinutes = 5 * 60 + 30;
  return format(addMinutes(new Date(dateTimeString), istOffsetMinutes), 'd MMM yyyy HH:mm');
};
/**
 * Express a duration in seconds using the largest whole unit that fits
 * (weeks, days, hours, minutes, otherwise seconds), rounding down.
 *
 * The unit is singular for a value of exactly 1 ("1 minute" rather than the previous
 * "1 minutes"), matching formatTimeDifference.
 *
 * @param {number} waitingPeriod - duration in seconds.
 * @returns {string} e.g. "30 seconds", "1 hour", "2 weeks".
 */
export const convertWaitingPeriod = (waitingPeriod) => {
  // Ordered from largest to smallest so the first match is the biggest unit.
  const units = [
    {seconds: 604800, name: 'week'},
    {seconds: 86400, name: 'day'},
    {seconds: 3600, name: 'hour'},
    {seconds: 60, name: 'minute'},
  ];
  let convertedValue = waitingPeriod;
  let unit = 'second';
  for (const candidate of units) {
    if (waitingPeriod >= candidate.seconds) {
      convertedValue = Math.floor(waitingPeriod / candidate.seconds);
      unit = candidate.name;
      break;
    }
  }
  return `${convertedValue} ${convertedValue === 1 ? unit : unit + 's'}`;
};
/**
 * Read utm_source / utm_medium / utm_campaign from the current page URL.
 * @returns {Object|null} all three keys (missing ones as ''), or null if none is set.
 */
export const getUTMParametersFromURL = () => {
  const searchParams = new URLSearchParams(window.location.search);
  const keys = ['utm_source', 'utm_medium', 'utm_campaign'];
  const utmParams = {};
  keys.forEach((key) => {
    utmParams[key] = searchParams.get(key) || '';
  });
  const hasAny = keys.some((key) => utmParams[key]);
  return hasAny ? utmParams : null;
};
/** Track a Mixpanel event, only in the 'PROD' environment with Mixpanel configured. */
export const getUserClick = (event, props) => {
  const isProd = localStorage.getItem('applicationEnvironment') === 'PROD';
  if (!isProd || !mixpanelId()) {
    return;
  }
  mixpanel.track(event, props);
};
/**
 * Send one event to Google Analytics through the Measurement Protocol endpoint.
 *
 * The request is skipped when the measurement id or API secret is not configured,
 * rather than posting with "undefined" in the query string.
 * NOTE(review): the API secret is read on the client; confirm that exposure is acceptable.
 *
 * @param {string} client - GA client id.
 * @param {string} eventName - event name.
 * @param {Object} params - event parameters.
 * @returns {Promise<void>} rejects if the network request fails.
 */
export const sendGAEvent = async (client, eventName, params) => {
  const measurement_id = analyticsMeasurementId();
  const api_secret = analyticsApiSecret();
  if (!measurement_id || !api_secret) {
    return;
  }
  // URLSearchParams encodes both values safely.
  const query = new URLSearchParams({measurement_id, api_secret}).toString();
  await fetch(`https://www.google-analytics.com/mp/collect?${query}`, {
    method: "POST",
    body: JSON.stringify({
      client_id: client,
      events: [{
        name: eventName,
        params: params
      }]
    })
  });
};
================================================
FILE: install_tool_dependencies.sh
================================================
#!/bin/bash
#
# Install system (apt) and Python (pip) dependencies for SuperAGI itself and for
# every tool directory under /app/superagi/tools that ships requirements files.
#
# Best-effort by design: a failing step is reported on stderr and the remaining
# steps still run, so one broken tool does not block the others. The script
# exits non-zero at the end if any step failed.

set -u

# Never block on interactive debconf prompts (e.g. tzdata) inside a container.
export DEBIAN_FRONTEND=noninteractive
# Globs that match nothing (e.g. an empty external_tools/) expand to nothing.
shopt -s nullglob

failures=0

#######################################
# Run a command; on failure log it and count it instead of aborting.
# Globals:   failures (incremented on failure)
# Arguments: the command and its arguments
# Returns:   the command's exit status
#######################################
run_step() {
  "$@" && return 0
  local status=$?
  printf 'warning: command failed (exit %d): %s\n' "$status" "$*" >&2
  failures=$((failures + 1))
  return "$status"
}

#######################################
# apt-get install every package listed in a requirements_apt.txt file.
# Arguments: $1 - path to a whitespace/newline separated package list
#######################################
install_apt_file() {
  # -r: skip apt-get entirely when the list is empty.
  xargs -r apt-get install -y < "$1"
}

# Update and upgrade apt settings and apps. apt-get replaces apt, whose CLI is not
# guaranteed stable for scripts; --with-new-pkgs keeps `apt upgrade` semantics.
run_step apt-get update && run_step apt-get upgrade -y --with-new-pkgs
run_step install_apt_file /app/requirements_apt.txt
# Run the project's main requirements.txt
run_step pip install -r /app/requirements.txt

for tool in /app/superagi/tools/* /app/superagi/tools/external_tools/* /app/superagi/tools/marketplace_tools/*; do
  [[ -d "$tool" ]] || continue
  tool_name=${tool##*/}
  # Install the tool's apt requirements, if it has any
  if [[ -f "$tool/requirements_apt.txt" ]]; then
    printf '%s\n' "Installing apt requirements for tool: $tool_name"
    run_step install_apt_file "$tool/requirements_apt.txt"
  fi
  # Install the tool's Python requirements, if it has any
  if [[ -f "$tool/requirements.txt" ]]; then
    printf '%s\n' "Installing requirements for tool: $tool_name"
    run_step pip install -r "$tool/requirements.txt"
  fi
done

if (( failures > 0 )); then
  printf '%d installation step(s) failed; see warnings above\n' "$failures" >&2
  exit 1
fi
================================================
FILE: local-llm
================================================
# Docker Compose stack: SuperAGI (backend, celery, gui) with Redis, Postgres and a
# local model server (super__tgwui), CPU variant.
# NOTE(review): "tgwui" presumably refers to text-generation-webui — confirm.
version: '3.8'
services:
  # API server; the repository root is bind-mounted at /app.
  backend:
    volumes:
      - "./:/app"
    build: .
    ports:
      - "8001:8001"
    # depends_on only orders container start-up; it does not wait for the
    # dependencies to accept connections.
    depends_on:
      - super__tgwui
      - super__redis
      - super__postgres
    networks:
      - super_network
  # Celery worker built from DockerfileCelery, sharing the same source mount.
  celery:
    volumes:
      - "./:/app"
    build:
      context: .
      dockerfile: DockerfileCelery
    depends_on:
      - super__tgwui
      - super__redis
      - super__postgres
    networks:
      - super_network
  gui:
    build: ./gui
    ports:
      - "3000:3000"
    environment:
      # Used as the axios base URL in the browser, so it must be reachable from
      # the user's machine, not just from inside the network.
      - NEXT_PUBLIC_API_BASE_URL=http://localhost:8001
    networks:
      - super_network
    volumes:
      - ./gui:/app
      # Anonymous volumes keep the image's node_modules and .next from being
      # hidden by the ./gui bind mount above.
      - /app/node_modules
      - /app/.next
  super__tgwui:
    build:
      context: .
      dockerfile: ./tgwui/DockerfileTGWUI
    container_name: super__tgwui
    environment:
      # In list form the double quotes become part of the value; presumably the
      # image's launcher strips them — confirm before converting to map form.
      - EXTRA_LAUNCH_ARGS="--listen --verbose --extensions openai --threads 4 --n_ctx 1600"
    ports:
      - 7860:7860 # Default web port
      - 5000:5000 # Default API port
      - 5005:5005 # Default streaming port
      - 5001:5001 # Default OpenAI API extension port
    volumes:
      - ./tgwui/config/loras:/app/loras
      - ./tgwui/config/models:/app/models
      - ./tgwui/config/presets:/app/presets
      - ./tgwui/config/prompts:/app/prompts
      - ./tgwui/config/softprompts:/app/softprompts
      - ./tgwui/config/training:/app/training
    logging:
      driver: json-file
      options:
        max-file: "3" # maximum number of rotated log files kept
        max-size: '10m' # size at which the current log file is rotated
    networks:
      - super_network
  super__redis:
    image: "docker.io/library/redis:latest"
    networks:
      - super_network
  super__postgres:
    # NOTE(review): ":latest" is unpinned; a newer Postgres major version cannot
    # start on a data directory initialised by an older one.
    image: "docker.io/library/postgres:latest"
    environment:
      # Development-only credentials.
      - POSTGRES_USER=superagi
      - POSTGRES_PASSWORD=password
      - POSTGRES_DB=super_agi_main
    volumes:
      - superagi_postgres_data:/var/lib/postgresql/data/
    networks:
      - super_network
    # Published on all host interfaces; restrict (e.g. "127.0.0.1:5432:5432")
    # when the host is reachable by others.
    ports:
      - "5432:5432"
networks:
  super_network:
    driver: bridge
volumes:
  superagi_postgres_data:
FILE: local-llm-gpu
================================================
# Docker Compose stack: SuperAGI (backend, celery, gui) with Redis, Postgres and a
# local model server (super__tgwui), NVIDIA GPU variant.
version: '3.8'
services:
  # API server; the repository root is bind-mounted at /app.
  backend:
    volumes:
      - "./:/app"
    build: .
    ports:
      - "8001:8001"
    # depends_on only orders container start-up; it does not wait for the
    # dependencies to accept connections.
    depends_on:
      - super__tgwui
      - super__redis
      - super__postgres
    networks:
      - super_network
  # Celery worker built from DockerfileCelery, sharing the same source mount.
  celery:
    volumes:
      - "./:/app"
    build:
      context: .
      dockerfile: DockerfileCelery
    depends_on:
      - super__tgwui
      - super__redis
      - super__postgres
    networks:
      - super_network
  gui:
    build: ./gui
    ports:
      - "3000:3000"
    environment:
      # Used as the axios base URL in the browser, so it must be reachable from
      # the user's machine, not just from inside the network.
      - NEXT_PUBLIC_API_BASE_URL=http://localhost:8001
    networks:
      - super_network
    volumes:
      - ./gui:/app
      # Anonymous volumes keep the image's node_modules and .next from being
      # hidden by the ./gui bind mount above.
      - /app/node_modules
      - /app/.next
  super__tgwui:
    build:
      context: ./tgwui/
      # Build stage selected from DockerfileTGWUI.
      target: llama-cublas
      dockerfile: DockerfileTGWUI
      # args:
      #   - LCL_SRC_DIR=text-generation-webui # Developers - see Dockerfile app_base
    # With both build and image set, the built image is tagged with this name.
    image: atinoda/text-generation-webui:llama-cublas # Specify variant as the :tag
    container_name: super__tgwui
    environment:
      # In list form the double quotes become part of the value; presumably the
      # image's launcher strips them — confirm before converting to map form.
      # NOTE(review): "--gpu-memory 22 22" presumably sets a per-GPU limit for the
      # two devices reserved below — confirm.
      - EXTRA_LAUNCH_ARGS="--no-mmap --verbose --extensions openai --auto-devices --n_ctx 2000 --gpu-memory 22 22 --n-gpu-layers 128 --threads 8"
      # - BUILD_EXTENSIONS_LIVE="silero_tts whisper_stt" # Install named extensions during every container launch. THIS WILL SIGNIFICANTLY SLOW LAUNCH TIME.
    ports:
      - 7860:7860 # Default web port
      - 5000:5000 # Default API port
      - 5005:5005 # Default streaming port
      - 5001:5001 # Default OpenAI API extension port
    volumes:
      - ./tgwui/config/loras:/app/loras
      - ./tgwui/config/models:/app/models
      - ./tgwui/config/presets:/app/presets
      - ./tgwui/config/prompts:/app/prompts
      - ./tgwui/config/softprompts:/app/softprompts
      - ./tgwui/config/training:/app/training
      # - ./config/extensions:/app/extensions
    logging:
      driver: json-file
      options:
        max-file: "3" # maximum number of rotated log files kept
        max-size: '10m' # size at which the current log file is rotated
    networks:
      - super_network
    # Reserve NVIDIA GPUs 0 and 1 for this container.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              # count: "all"
              device_ids: ['0', '1'] # must comment the above line if this line is uncommented.
              capabilities: [gpu]
  super__redis:
    image: "docker.io/library/redis:latest"
    networks:
      - super_network
  super__postgres:
    # NOTE(review): ":latest" is unpinned; a newer Postgres major version cannot
    # start on a data directory initialised by an older one.
    image: "docker.io/library/postgres:latest"
    environment:
      # Development-only credentials.
      - POSTGRES_USER=superagi
      - POSTGRES_PASSWORD=password
      - POSTGRES_DB=super_agi_main
    volumes:
      - superagi_postgres_data:/var/lib/postgresql/data/
    networks:
      - super_network
    # Published on all host interfaces; restrict (e.g. "127.0.0.1:5432:5432")
    # when the host is reachable by others.
    ports:
      - "5432:5432"
networks:
  super_network:
    driver: bridge
volumes:
  superagi_postgres_data:
================================================
FILE: main.py
================================================
import requests
from fastapi import FastAPI, HTTPException, Depends, Request, status, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from fastapi_sqlalchemy import DBSessionMiddleware, db
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import superagi
from datetime import timedelta, datetime
from superagi.agent.workflow_seed import IterationWorkflowSeed, AgentWorkflowSeed
from superagi.config.config import get_config
from superagi.controllers.agent import router as agent_router
from superagi.controllers.agent_execution import router as agent_execution_router
from superagi.controllers.agent_execution_feed import router as agent_execution_feed_router
from superagi.controllers.agent_execution_permission import router as agent_execution_permission_router
from superagi.controllers.agent_template import router as agent_template_router
from superagi.controllers.agent_workflow import router as agent_workflow_router
from superagi.controllers.budget import router as budget_router
from superagi.controllers.config import router as config_router
from superagi.controllers.organisation import router as organisation_router
from superagi.controllers.project import router as project_router
from superagi.controllers.twitter_oauth import router as twitter_oauth_router
from superagi.controllers.google_oauth import router as google_oauth_router
from superagi.controllers.resources import router as resources_router
from superagi.controllers.tool import router as tool_router
from superagi.controllers.tool_config import router as tool_config_router
from superagi.controllers.toolkit import router as toolkit_router
from superagi.controllers.user import router as user_router
from superagi.controllers.agent_execution_config import router as agent_execution_config
from superagi.controllers.analytics import router as analytics_router
from superagi.controllers.models_controller import router as models_controller_router
from superagi.controllers.knowledges import router as knowledges_router
from superagi.controllers.knowledge_configs import router as knowledge_configs_router
from superagi.controllers.vector_dbs import router as vector_dbs_router
from superagi.controllers.vector_db_indices import router as vector_db_indices_router
from superagi.controllers.marketplace_stats import router as marketplace_stats_router
from superagi.controllers.api_key import router as api_key_router
from superagi.controllers.api.agent import router as api_agent_router
from superagi.controllers.webhook import router as web_hook_router
from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.llm_model_factory import build_model_with_api_key
from superagi.llms.openai import OpenAi
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
from superagi.models.agent_template import AgentTemplate
from superagi.models.models_config import ModelsConfig
from superagi.models.organisation import Organisation
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
from superagi.models.user import User
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
from urllib.parse import urlparse
# FastAPI application; all routers and middleware below are attached to it.
app = FastAPI()

# Database settings, each read from configuration.
db_host = get_config('DB_HOST', 'super__postgres')
db_url = get_config('DB_URL', None)
db_username = get_config('DB_USERNAME')
db_password = get_config('DB_PASSWORD')
db_name = get_config('DB_NAME')
env = get_config('ENV', "DEV")

if db_url is not None:
    # An explicit DB_URL wins; only its scheme, network location and path are kept.
    db_url = "{0.scheme}://{0.netloc}{0.path}".format(urlparse(db_url))
elif db_username is None:
    # No username configured: connect without credentials.
    db_url = f'postgresql://{db_host}/{db_name}'
else:
    db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}'

# Engine used by the startup hook to open its own session.
engine = create_engine(
    db_url,
    pool_size=20,         # connections kept in the pool
    max_overflow=50,      # extra connections allowed beyond pool_size
    pool_timeout=30,      # seconds to wait for a pooled connection
    pool_recycle=1800,    # recycle connections older than this many seconds
    pool_pre_ping=False,  # connection liveness check on checkout is disabled
)

# DBSessionMiddleware is configured with the same URL; handlers below use db.session.
app.add_middleware(DBSessionMiddleware, db_url=db_url)
# Configure CORS middleware.
# `origins` is the single list of allowed origins passed to the middleware;
# "*" allows every origin. Add more entries here if needed.
origins = [
    "*",  # Allow all origins
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Creating required tables -- Now handled using migrations
# DBBaseModel.metadata.create_all(bind=engine, checkfirst=True)
# DBBaseModel.metadata.drop_all(bind=engine,checkfirst=True)

# Register every API router under its URL prefix. Each router is included
# exactly once: including the same router twice under the same prefix only
# registers duplicate routes (and duplicate OpenAPI operation ids).
app.include_router(user_router, prefix="/users")
app.include_router(tool_router, prefix="/tools")
app.include_router(organisation_router, prefix="/organisations")
app.include_router(project_router, prefix="/projects")
app.include_router(budget_router, prefix="/budgets")
app.include_router(agent_router, prefix="/agents")
app.include_router(agent_execution_router, prefix="/agentexecutions")
app.include_router(agent_execution_feed_router, prefix="/agentexecutionfeeds")
app.include_router(agent_execution_permission_router, prefix="/agentexecutionpermissions")
app.include_router(resources_router, prefix="/resources")
app.include_router(config_router, prefix="/configs")
app.include_router(toolkit_router, prefix="/toolkits")
app.include_router(tool_config_router, prefix="/tool_configs")
app.include_router(agent_template_router, prefix="/agent_templates")
app.include_router(agent_workflow_router, prefix="/agent_workflows")
app.include_router(twitter_oauth_router, prefix="/twitter")
app.include_router(agent_execution_config, prefix="/agent_executions_configs")
app.include_router(analytics_router, prefix="/analytics")
app.include_router(models_controller_router, prefix="/models_controller")
app.include_router(google_oauth_router, prefix="/google")
app.include_router(knowledges_router, prefix="/knowledges")
app.include_router(knowledge_configs_router, prefix="/knowledge_configs")
app.include_router(vector_dbs_router, prefix="/vector_dbs")
app.include_router(vector_db_indices_router, prefix="/vector_db_indices")
app.include_router(marketplace_stats_router, prefix="/marketplace")
app.include_router(api_key_router, prefix="/api-keys")
app.include_router(api_agent_router, prefix="/v1/agent")
app.include_router(web_hook_router, prefix="/webhook")
# in production you can use Settings management
# from pydantic to get secret key from .env
class Settings(BaseModel):
    """AuthJWT configuration, returned by the @AuthJWT.load_config callback.

    The default for `authjwt_secret_key` is the JWT_SECRET_KEY config value,
    read once when this class is defined (at module import).
    """
    # jwt_secret = get_config("JWT_SECRET_KEY")
    authjwt_secret_key: str = superagi.config.config.get_config("JWT_SECRET_KEY")
def create_access_token(email, Authorize: AuthJWT = Depends()):
    """Create a signed JWT access token whose subject is `email`.

    The token lifetime comes from the JWT_EXPIRY config value, in hours. A
    string value is converted with int(); a missing or blank value falls back
    to 200 hours.

    Args:
        email: Value stored as the token subject.
        Authorize: AuthJWT instance used to sign the token.

    Returns:
        The token produced by Authorize.create_access_token.

    Raises:
        ValueError: if JWT_EXPIRY is a non-blank string that is not an integer.
    """
    expiry_time_hours = superagi.config.config.get_config("JWT_EXPIRY")
    if isinstance(expiry_time_hours, str):
        # A blank value means "not configured" rather than int('') raising ValueError.
        expiry_time_hours = int(expiry_time_hours) if expiry_time_hours.strip() else None
    if expiry_time_hours is None:
        expiry_time_hours = 200
    expires = timedelta(hours=expiry_time_hours)
    access_token = Authorize.create_access_token(subject=email, expires_time=expires)
    return access_token
# callback to get your configuration
# NOTE(review): this def rebinds the module-level name `get_config` to whatever
# AuthJWT.load_config returns, shadowing the configuration helper used earlier in
# this module (e.g. for DB_HOST). Code after this point calls
# superagi.config.config.get_config explicitly instead; keep it that way.
@AuthJWT.load_config
def get_config():
    # Supplies AuthJWT with its settings (the JWT secret key) via Settings.
    return Settings()
# Exception handler for AuthJWT errors: each one becomes a JSON error response.
# (A faster JSON response class such as orjson could be swapped in for production.)
@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
    """Return the AuthJWT error message as {"detail": ...} with the exception's status code."""
    payload = {"detail": exc.message}
    return JSONResponse(content=payload, status_code=exc.status_code)
def replace_old_iteration_workflows(session):
    """Re-point agent templates from legacy iteration workflows to agent workflows.

    For every AgentTemplate whose agent_workflow_id resolves to an
    IterationWorkflow with one of the legacy names below, the template is
    switched to the AgentWorkflow with the mapped name and the change is
    committed. Templates whose workflow id does not resolve, or resolves to a
    non-legacy name, are left unchanged.

    Args:
        session: SQLAlchemy session used for the lookups and commits.
    """
    # Legacy iteration-workflow name -> replacement agent-workflow name.
    legacy_to_agent_workflow = {
        "Fixed Task Queue": "Fixed Task Workflow",
        "Maintain Task Queue": "Dynamic Task Workflow",
        "Don't Maintain Task Queue": "Goal Based Workflow",
        "Goal Based Agent": "Goal Based Workflow",
    }
    templates = session.query(AgentTemplate).all()
    for template in templates:
        iter_workflow = IterationWorkflow.find_by_id(session, template.agent_workflow_id)
        if not iter_workflow:
            continue
        target_name = legacy_to_agent_workflow.get(iter_workflow.name)
        if target_name is None:
            continue
        agent_workflow = AgentWorkflow.find_by_name(session, target_name)
        if agent_workflow is None:
            # Presumably find_by_name yields None when the workflow is not seeded;
            # skip the template instead of raising AttributeError and aborting startup.
            logger.error(f"Agent workflow '{target_name}' not found; template {template.id} left unchanged")
            continue
        template.agent_workflow_id = agent_workflow.id
        session.commit()
@app.on_event("startup")
async def startup_event():
    """Seed workflows and register toolkits when the application starts.

    Steps, in order:
    1. Register local toolkits for the organisation of the default user
       (super6@agi.com), if that user exists.
    2. Seed the iteration workflows and agent workflows.
    3. Delete agent workflows whose names are not in the current seeded set,
       and commit the deletions.
    4. Re-point agent templates away from legacy iteration workflows.
    5. Register toolkits for every organisation (ENV != "PROD") or marketplace
       toolkits for MARKETPLACE_ORGANISATION_ID only (ENV == "PROD").

    The session is closed even if one of these steps raises.
    """
    logger.info("Running Startup tasks")
    Session = sessionmaker(bind=engine)
    session = Session()
    try:
        default_user = session.query(User).filter(User.email == "super6@agi.com").first()
        logger.info(default_user)
        if default_user is not None:
            organisation = session.query(Organisation).filter_by(id=default_user.organisation_id).first()
            logger.info(organisation)
            register_toolkits(session, organisation)

        def register_toolkit_for_all_organisation():
            """Register local toolkits for every organisation."""
            organizations = session.query(Organisation).all()
            for organization in organizations:
                register_toolkits(session, organization)
            logger.info("Successfully registered local toolkits for all Organisations!")

        def register_toolkit_for_master_organisation():
            """Register marketplace toolkits for the MARKETPLACE_ORGANISATION_ID organisation, if it exists."""
            marketplace_organisation_id = superagi.config.config.get_config("MARKETPLACE_ORGANISATION_ID")
            marketplace_organisation = session.query(Organisation).filter(
                Organisation.id == marketplace_organisation_id).first()
            if marketplace_organisation is not None:
                register_marketplace_toolkits(session, marketplace_organisation)

        IterationWorkflowSeed.build_single_step_agent(session)
        IterationWorkflowSeed.build_task_based_agents(session)
        IterationWorkflowSeed.build_action_based_agents(session)
        IterationWorkflowSeed.build_initialize_task_workflow(session)
        AgentWorkflowSeed.build_goal_based_agent(session)
        AgentWorkflowSeed.build_task_based_agent(session)
        AgentWorkflowSeed.build_fixed_task_based_agent(session)
        AgentWorkflowSeed.build_sales_workflow(session)
        AgentWorkflowSeed.build_recruitment_workflow(session)
        AgentWorkflowSeed.build_coding_workflow(session)

        # NOTE: remove old workflows. Need to remove this changes later
        current_workflow_names = ["Sales Engagement Workflow", "Recruitment Workflow", "SuperCoder",
                                  "Goal Based Workflow", "Dynamic Task Workflow", "Fixed Task Workflow"]
        stale_workflows = session.query(AgentWorkflow).filter(AgentWorkflow.name.not_in(current_workflow_names))
        for workflow in stale_workflows:
            session.delete(workflow)
        # Persist the deletions explicitly rather than relying on a later commit
        # (replace_old_iteration_workflows only commits when a template changes).
        session.commit()

        # AgentWorkflowSeed.doc_search_and_code(session)
        # AgentWorkflowSeed.build_research_email_workflow(session)
        replace_old_iteration_workflows(session)

        if env != "PROD":
            register_toolkit_for_all_organisation()
        else:
            register_toolkit_for_master_organisation()
    finally:
        # Always release the connection back to the pool, even when seeding fails.
        session.close()
@app.post('/login')
def login(request: LoginRequest, Authorize: AuthJWT = Depends()):
    """Login API for email and password based login.

    Args:
        request: Login payload carrying `email` and `password`.
        Authorize: AuthJWT instance used to sign the token.

    Returns:
        dict: {"access_token": <JWT whose subject is the user's email>}

    Raises:
        HTTPException: 401 when no user has this email, the user has no stored
            password (e.g. accounts created through GitHub login), or the
            password does not match.
    """
    import hmac

    email_to_find = request.email
    user: User = db.session.query(User).filter(User.email == email_to_find).first()
    if user is None or request.email != user.email:
        raise HTTPException(status_code=401, detail="Bad username or password")
    # Accounts without a stored password must not be loginable here; previously a
    # None password in both places compared equal and passed the check.
    if user.password is None or request.password is None:
        raise HTTPException(status_code=401, detail="Bad username or password")
    # NOTE(review): passwords are compared against the stored value as-is, so they
    # appear to be stored unhashed; consider hashing them.
    # Constant-time comparison so response timing does not leak how much matched.
    if not hmac.compare_digest(str(request.password).encode("utf-8"), str(user.password).encode("utf-8")):
        raise HTTPException(status_code=401, detail="Bad username or password")
    # subject identifier for who this token is for example id or username from database
    access_token = create_access_token(user.email, Authorize)
    return {"access_token": access_token}
# def get_jwt_from_payload(user_email: str,Authorize: AuthJWT = Depends()):
# access_token = Authorize.create_access_token(subject=user_email)
# return access_token
@app.get('/github-login')
def github_login():
    """GitHub login: redirect the browser to GitHub's OAuth authorization page.

    The client id is the GITHUB_CLIENT_ID config value, stripped of surrounding
    whitespace; an unset value yields an empty client_id.
    """
    # Read the client id from config, as github_auth_handler does for the token
    # exchange. A hard-coded "" never identifies this application to GitHub.
    github_client_id = str(superagi.config.config.get_config("GITHUB_CLIENT_ID") or "").strip()
    return RedirectResponse(f'https://github.com/login/oauth/authorize?scope=user:email&client_id={github_client_id}')
@app.get('/github-auth')
def github_auth_handler(code: str = Query(...), Authorize: AuthJWT = Depends()):
    """GitHub login callback.

    Exchanges the OAuth `code` for a GitHub access token, then fetches the
    GitHub profile. On first login it creates a local User. It then redirects
    to FRONTEND_URL with a JWT (`access_token`) and a `first_time_login` flag.
    Every failure path redirects to https://superagi.com/; these are network
    errors, non-2xx responses, and a missing access token.

    Args:
        code: OAuth authorization code supplied by GitHub.
        Authorize: AuthJWT instance used to sign the JWT.
    """
    github_token_url = 'https://github.com/login/oauth/access_token'
    github_api_url = 'https://api.github.com/user'
    github_client_id = superagi.config.config.get_config("GITHUB_CLIENT_ID")
    github_client_secret = superagi.config.config.get_config("GITHUB_CLIENT_SECRET")
    frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
    redirect_url_failure = "https://superagi.com/"
    # requests waits forever by default; bound the calls so a stalled GitHub
    # connection cannot block this worker indefinitely.
    timeout_seconds = 30

    params = {
        'client_id': github_client_id,
        'client_secret': github_client_secret,
        'code': code
    }
    headers = {
        'Accept': 'application/json'
    }
    try:
        response = requests.post(github_token_url, params=params, headers=headers, timeout=timeout_seconds)
    except requests.RequestException as exc:
        logger.error(f"GitHub token exchange failed: {exc}")
        return RedirectResponse(url=redirect_url_failure)
    if not response.ok:
        return RedirectResponse(url=redirect_url_failure)

    access_token = response.json().get('access_token')
    if not access_token:
        # No token (e.g. an invalid or expired code): the profile request cannot succeed.
        return RedirectResponse(url=redirect_url_failure)

    headers = {
        'Authorization': f'Bearer {access_token}'
    }
    try:
        response = requests.get(github_api_url, headers=headers, timeout=timeout_seconds)
    except requests.RequestException as exc:
        logger.error(f"GitHub user lookup failed: {exc}")
        return RedirectResponse(url=redirect_url_failure)
    if not response.ok:
        return RedirectResponse(url=redirect_url_failure)

    user_data = response.json()
    user_email = user_data.get("email")
    if user_email is None:
        # No public email on the GitHub profile: derive a stable placeholder from the login.
        user_email = user_data["login"] + "@github.com"

    db_user: User = db.session.query(User).filter(User.email == user_email).first()
    if db_user is not None:
        jwt_token = create_access_token(user_email, Authorize)
        redirect_url_success = f"{frontend_url}?access_token={jwt_token}&first_time_login={False}"
        return RedirectResponse(url=redirect_url_success)

    user = User(name=user_data.get("name"), email=user_email)
    db.session.add(user)
    db.session.commit()
    jwt_token = create_access_token(user_email, Authorize)
    redirect_url_success = f"{frontend_url}?access_token={jwt_token}&first_time_login={True}"
    return RedirectResponse(url=redirect_url_success)
@app.get('/user')
def user(Authorize: AuthJWT = Depends()):
    """API to get current logged in User.

    Requires a valid JWT and returns {"user": <token subject>}.
    """
    Authorize.jwt_required()
    return {"user": Authorize.get_jwt_subject()}
@app.get("/validate-access-token")
async def root(Authorize: AuthJWT = Depends()):
    """API to validate access token.

    Returns:
        The User row whose email equals the token subject (None if no such user).

    Raises:
        HTTPException: 401 if the token is missing/invalid or the lookup fails.
    """
    # NOTE(review): the whole User object is returned and the users table has a
    # `password` column; confirm it is not serialised back to the client.
    try:
        Authorize.jwt_required()
        current_user_email = Authorize.get_jwt_subject()
        current_user = db.session.query(User).filter(User.email == current_user_email).first()
        return current_user
    except Exception:
        # Narrowed from a bare `except:` so BaseException-only signals
        # (asyncio.CancelledError, KeyboardInterrupt, SystemExit) are not turned into 401s.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
@app.post("/validate-llm-api-key")
async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJWT = Depends()):
    """API to validate LLM API Key.

    Builds a model client for `model_source` with `model_api_key` and asks it
    to verify the key. A source that yields no model counts as invalid.
    """
    model = build_model_with_api_key(request.model_source, request.model_api_key)
    if model is None or not model.verify_access_key():
        return {"message": "Invalid API Key", "status": "failed"}
    return {"message": "Valid API Key", "status": "success"}
@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
    """API to validate Open AI Key.

    Sends a one-message chat completion with the given key. Returns null when
    the call does not raise, and 401 when it does.
    """
    # NOTE(review): `Authorize` is injected but jwt_required() is never called, so
    # this route does not require a token; the key also travels in the URL path,
    # where it may end up in access logs.
    # NOTE(review): confirm OpenAi.chat_completion raises on a bad key rather than
    # returning an error payload; otherwise this check always passes.
    try:
        llm = OpenAi(api_key=open_ai_key)
        llm.chat_completion([{"role": "system", "content": "Hey!"}])
    except Exception:
        # Narrowed from a bare `except:` so cancellation/interrupts are not swallowed.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
# Protected route: jwt_required() rejects requests without a valid token.
@app.get("/hello/{name}")
async def say_hello(name: str, Authorize: AuthJWT = Depends()):
    """Greet `name`; requires a valid JWT."""
    Authorize.jwt_required()
    greeting = f"Hello {name}"
    return {"message": greeting}
@app.get('/get/github_client_id')
def github_client_id():
    """Get GitHub Client ID.

    Returns the GITHUB_CLIENT_ID config value with surrounding whitespace
    removed; a falsy value (unset or empty) is returned unchanged.
    """
    client_id = superagi.config.config.get_config("GITHUB_CLIENT_ID")
    return {"github_client_id": client_id.strip() if client_id else client_id}
# # __________________TO RUN____________________________
# # uvicorn main:app --host 0.0.0.0 --port 8001 --reload
================================================
FILE: migrations/README
================================================
Generic single-database configuration.
================================================
FILE: migrations/env.py
================================================
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from urllib.parse import urlparse
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from superagi.models.base_model import DBBaseModel
target_metadata = DBBaseModel.metadata
# NOTE(review): presumably imported for its side effect of defining every model
# class on DBBaseModel.metadata, so autogenerate sees all tables; confirm.
from superagi.models import *
from superagi.config.config import get_config

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

# Database settings read once at import time. Only run_migrations_offline()
# uses these module-level values; run_migrations_online() reads them again.
db_host = get_config('DB_HOST', 'super__postgres')
db_username = get_config('DB_USERNAME')
db_password = get_config('DB_PASSWORD')
db_name = get_config('DB_NAME')
database_url = get_config('DB_URL', None)
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Alembic is configured with a URL only, so no Engine or DBAPI is needed;
    context.execute() calls emit SQL text to the script output.

    The URL is built as follows:
    - If DB_URL is set, keep only its scheme, network location and path.
    - Otherwise, build it from DB_HOST and DB_NAME.
    - Add DB_USERNAME and DB_PASSWORD to it when a username is configured.
    """
    if database_url is not None:
        parsed = urlparse(database_url)
        resolved_url = parsed.scheme + "://" + parsed.netloc + parsed.path
    elif db_username is None:
        resolved_url = f'postgresql://{db_host}/{db_name}'
    else:
        resolved_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}'

    # Store the URL in the Alembic config and read it back, so the value used
    # is exactly what the config holds.
    config.set_main_option("sqlalchemy.url", resolved_url)
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Re-reads the database settings and stores the resulting URL as
    `sqlalchemy.url` in the Alembic config. It then builds a NullPool Engine
    from the ini section and runs migrations over a live connection.
    """
    host = get_config('DB_HOST', 'super__postgres')
    username = get_config('DB_USERNAME')
    password = get_config('DB_PASSWORD')
    name = get_config('DB_NAME')
    explicit_url = get_config('DB_URL', None)

    if explicit_url is not None:
        # Keep only scheme, network location and path of an explicit DB_URL.
        parts = urlparse(explicit_url)
        db_url = parts.scheme + "://" + parts.netloc + parts.path
    elif username is None:
        db_url = f'postgresql://{host}/{name}'
    else:
        db_url = f'postgresql://{username}:{password}@{host}/{name}'
    config.set_main_option('sqlalchemy.url', db_url)

    ini_section = config.get_section(config.config_ini_section, {})
    connectable = engine_from_config(ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool)

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
# Choose the runner that matches the mode Alembic was invoked in, then run it.
migration_runner = run_migrations_offline if context.is_offline_mode() else run_migrations_online
migration_runner()
================================================
FILE: migrations/script.py.mako
================================================
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
================================================
FILE: migrations/versions/1d54db311055_add_permissions.py
================================================
"""add permissions
Revision ID: 1d54db311055
Revises: 516ecc1c723d
Create Date: 2023-06-14 11:05:59.678961
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# Alembic orders migrations by down_revision (the parent revision), not by the docstring.
revision = '1d54db311055'
down_revision = '516ecc1c723d'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create agent_execution_permissions, add agent_executions.permission_id,
    and index agent_execution_permissions.agent_execution_id."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('agent_execution_permissions',
                    sa.Column('created_at', sa.DateTime(), nullable=True),
                    sa.Column('updated_at', sa.DateTime(), nullable=True),
                    sa.Column('id', sa.Integer(), nullable=False),
                    sa.Column('agent_execution_id', sa.Integer(), nullable=True),
                    sa.Column('agent_id', sa.Integer(), nullable=True),
                    sa.Column('status', sa.String(), nullable=True),
                    sa.Column('tool_name', sa.String(), nullable=True),
                    sa.Column('user_feedback', sa.Text(), nullable=True),
                    sa.Column('assistant_reply', sa.Text(), nullable=True),
                    sa.PrimaryKeyConstraint('id')
                    )
    op.add_column('agent_executions', sa.Column('permission_id', sa.Integer(), nullable=True))
    # index on agent_execution_id
    op.create_index(op.f('ix_agent_execution_permissions_agent_execution_id')
                    , 'agent_execution_permissions', ['agent_execution_id'], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Reverse upgrade(): drop agent_executions.permission_id and the permissions table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('agent_executions', 'permission_id')
    # The agent_execution_id index is not dropped explicitly; PostgreSQL removes a
    # table's indexes together with the table.
    op.drop_table('agent_execution_permissions')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/2cc1179834b0_agent_executions_modified.py
================================================
"""agent_executions_modified
Revision ID: 2cc1179834b0
Revises: 2f97c068fab9
Create Date: 2023-06-02 21:01:43.303961
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2cc1179834b0'
down_revision = '2f97c068fab9'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add nullable integer columns `calls` and `tokens` to agent_executions."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('agent_executions', sa.Column('calls', sa.Integer(), nullable=True))
    op.add_column('agent_executions', sa.Column('tokens', sa.Integer(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the `tokens` and `calls` columns added by upgrade(); their values are lost."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('agent_executions', 'tokens')
    op.drop_column('agent_executions', 'calls')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/2f97c068fab9_resource_modified.py
================================================
"""Resource Modified
Revision ID: 2f97c068fab9
Revises: a91808a89623
Create Date: 2023-06-02 13:13:21.670935
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2f97c068fab9'
down_revision = 'a91808a89623'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add resources.agent_id and drop resources.project_id.

    Existing project_id values are discarded, not copied into agent_id.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('resources', sa.Column('agent_id', sa.Integer(), nullable=True))
    op.drop_column('resources', 'project_id')
    # ### end Alembic commands ###
def downgrade() -> None:
    """Re-add resources.project_id as an empty nullable column and drop resources.agent_id.

    The original project_id values are not recovered.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('resources', sa.Column('project_id', sa.INTEGER(), autoincrement=False, nullable=True))
    op.drop_column('resources', 'agent_id')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/2fbd6472112c_add_feed_group_id_to_execution_and_feed.py
================================================
"""add feed group id to execution and feed
Revision ID: 2fbd6472112c
Revises: 5184645e9f12
Create Date: 2023-08-01 17:09:16.183863
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2fbd6472112c'
down_revision = '5184645e9f12'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add agent_executions.current_feed_group_id and agent_execution_feeds.feed_group_id."""
    # server_default makes the database supply "DEFAULT" when no value is given.
    op.add_column('agent_executions',
                  sa.Column('current_feed_group_id', sa.String(), nullable=True, server_default="DEFAULT"))
    op.add_column('agent_execution_feeds', sa.Column('feed_group_id', sa.String(), nullable=True))
def downgrade() -> None:
    """Drop the two feed-group columns added by upgrade()."""
    op.drop_column('agent_executions', 'current_feed_group_id')
    op.drop_column('agent_execution_feeds', 'feed_group_id')
================================================
FILE: migrations/versions/3356a2f89a33_added_configurations_table.py
================================================
"""added_configurations_table
Revision ID: 3356a2f89a33
Revises: 35e47f20475b
Create Date: 2023-06-06 10:51:15.111738
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3356a2f89a33'
down_revision = '35e47f20475b'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the `configurations` table and drop four indexes.

    The dropped indexes are ix_aea_step_id, ix_ats_unique_id, ix_at_name and
    ix_agents_agnt_template_id; downgrade() recreates them.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('configurations',
                    sa.Column('created_at', sa.DateTime(), nullable=True),
                    sa.Column('updated_at', sa.DateTime(), nullable=True),
                    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                    sa.Column('organisation_id', sa.Integer(), nullable=True),
                    sa.Column('key', sa.String(), nullable=True),
                    sa.Column('value', sa.Text(), nullable=True),
                    sa.PrimaryKeyConstraint('id')
                    )
    op.drop_index('ix_aea_step_id', table_name='agent_executions')
    op.drop_index('ix_ats_unique_id', table_name='agent_template_steps')
    op.drop_index('ix_at_name', table_name='agent_templates')
    op.drop_index('ix_agents_agnt_template_id', table_name='agents')
    # ### end Alembic commands ###
def downgrade() -> None:
    """Recreate the four indexes dropped by upgrade() and drop `configurations`."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_index('ix_agents_agnt_template_id', 'agents', ['agent_template_id'], unique=False)
    op.create_index('ix_at_name', 'agent_templates', ['name'], unique=False)
    op.create_index('ix_ats_unique_id', 'agent_template_steps', ['unique_id'], unique=False)
    op.create_index('ix_aea_step_id', 'agent_executions', ['current_step_id'], unique=False)
    op.drop_table('configurations')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/35e47f20475b_renamed_tokens_calls.py
================================================
"""renamed_tokens_calls
Revision ID: 35e47f20475b
Revises: 598cfb37292a
Create Date: 2023-06-06 04:34:15.101672
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '35e47f20475b'
down_revision = '598cfb37292a'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Replace agent_executions.calls/tokens with num_of_calls/num_of_tokens."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): this "rename" is add + drop, so values stored in calls/tokens
    # are not carried over to num_of_calls/num_of_tokens.
    op.add_column('agent_executions', sa.Column('num_of_calls', sa.Integer(), nullable=True))
    op.add_column('agent_executions', sa.Column('num_of_tokens', sa.Integer(), nullable=True))
    op.drop_column('agent_executions', 'calls')
    op.drop_column('agent_executions', 'tokens')
    # ### end Alembic commands ###
def downgrade() -> None:
    """Restore empty calls/tokens columns and drop num_of_calls/num_of_tokens (values are lost)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('agent_executions', sa.Column('tokens', sa.INTEGER(), autoincrement=False, nullable=True))
    op.add_column('agent_executions', sa.Column('calls', sa.INTEGER(), autoincrement=False, nullable=True))
    op.drop_column('agent_executions', 'num_of_tokens')
    op.drop_column('agent_executions', 'num_of_calls')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/3867bb00a495_added_first_login_source.py
================================================
"""added_first_login_source
Revision ID: 3867bb00a495
Revises: 661ec8a4c32e
Create Date: 2023-09-15 02:06:24.006555
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3867bb00a495'
down_revision = '661ec8a4c32e'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add nullable users.first_login_source."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('users', sa.Column('first_login_source', sa.String(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop users.first_login_source."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('users', 'first_login_source')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/40affbf3022b_add_filter_colume_in_webhooks.py
================================================
"""add filter colume in webhooks
Revision ID: 40affbf3022b
Revises: 5d5f801f28e7
Create Date: 2023-08-28 12:30:35.171176
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '40affbf3022b'
down_revision = '5d5f801f28e7'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add nullable JSON column webhooks.filters."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('webhooks', sa.Column('filters', sa.JSON(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop webhooks.filters."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('webhooks', 'filters')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/446884dcae58_add_api_key_and_web_hook.py
================================================
"""add api_key and web_hook
Revision ID: 446884dcae58
Revises: 2fbd6472112c
Create Date: 2023-07-29 10:55:21.714245
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '446884dcae58'
down_revision = '2fbd6472112c'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the api_keys, webhooks and webhook_events tables."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): `default=False` on is_expired is a SQLAlchemy (client-side)
    # default; it does not create a database DEFAULT. Use server_default for that.
    op.create_table('api_keys',
                    sa.Column('id', sa.Integer(), nullable=False),
                    sa.Column('org_id', sa.Integer(), nullable=True),
                    sa.Column('name', sa.String(), nullable=True),
                    sa.Column('key', sa.String(), nullable=True),
                    sa.Column('created_at', sa.DateTime(), nullable=True),
                    sa.Column('updated_at', sa.DateTime(), nullable=True),
                    sa.Column('is_expired',sa.Boolean(),nullable=True,default=False),
                    sa.PrimaryKeyConstraint('id')
                    )
    op.create_table('webhooks',
                    sa.Column('id', sa.Integer(), nullable=False),
                    sa.Column('name', sa.String(), nullable=True),
                    sa.Column('org_id', sa.Integer(), nullable=True),
                    sa.Column('url', sa.String(), nullable=True),
                    sa.Column('headers', sa.JSON(), nullable=True),
                    sa.Column('created_at', sa.DateTime(), nullable=True),
                    sa.Column('updated_at', sa.DateTime(), nullable=True),
                    sa.Column('is_deleted',sa.Boolean(),nullable=True),
                    sa.PrimaryKeyConstraint('id')
                    )
    op.create_table('webhook_events',
                    sa.Column('id', sa.Integer(), nullable=False),
                    sa.Column('agent_id', sa.Integer(), nullable=True),
                    sa.Column('run_id', sa.Integer(), nullable=True),
                    sa.Column('event', sa.String(), nullable=True),
                    sa.Column('status', sa.String(), nullable=True),
                    sa.Column('errors', sa.Text(), nullable=True),
                    sa.Column('created_at', sa.DateTime(), nullable=True),
                    sa.Column('updated_at', sa.DateTime(), nullable=True),
                    sa.PrimaryKeyConstraint('id')
                    )
    #add index *********************
    # (placeholder: no secondary indexes are created by this revision)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the three tables created by upgrade()."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('webhooks')
    op.drop_table('api_keys')
    op.drop_table('webhook_events')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/44b0d6f2d1b3_init_models.py
================================================
"""init models
Revision ID: 44b0d6f2d1b3
Revises:
Create Date: 2023-06-01 11:55:35.195423
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '44b0d6f2d1b3'
down_revision = None
branch_labels = None
depends_on = None
from sqlalchemy.engine.reflection import Inspector


def upgrade() -> None:
    """Create the initial schema, skipping any table that already exists.

    The existing-table lookup runs here rather than at module import time.
    Alembic imports every revision file whenever it builds its revision map
    (e.g. for `alembic history` or `alembic heads`). op.get_bind() only works
    while a migration is actually running, so calling it at import failed.
    """
    conn = op.get_bind()
    inspector = Inspector.from_engine(conn)
    tables = inspector.get_table_names()
    # ### commands auto generated by Alembic - please adjust! ###
    if 'agent_configurations' not in tables:
        op.create_table('agent_configurations',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                        sa.Column('agent_id', sa.Integer(), nullable=True),
                        sa.Column('key', sa.String(), nullable=True),
                        sa.Column('value', sa.Text(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'agent_execution_feeds' not in tables:
        op.create_table('agent_execution_feeds',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), nullable=False),
                        sa.Column('agent_execution_id', sa.Integer(), nullable=True),
                        sa.Column('agent_id', sa.Integer(), nullable=True),
                        sa.Column('feed', sa.Text(), nullable=True),
                        sa.Column('role', sa.String(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'agent_executions' not in tables:
        op.create_table('agent_executions',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), nullable=False),
                        sa.Column('status', sa.String(), nullable=True),
                        sa.Column('agent_id', sa.Integer(), nullable=True),
                        sa.Column('last_execution_time', sa.DateTime(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'agents' not in tables:
        op.create_table('agents',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                        sa.Column('name', sa.String(), nullable=True),
                        sa.Column('project_id', sa.Integer(), nullable=True),
                        sa.Column('description', sa.String(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'budgets' not in tables:
        op.create_table('budgets',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), nullable=False),
                        sa.Column('budget', sa.Float(), nullable=True),
                        sa.Column('cycle', sa.String(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'organisations' not in tables:
        op.create_table('organisations',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), nullable=False),
                        sa.Column('name', sa.String(), nullable=True),
                        sa.Column('description', sa.String(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'projects' not in tables:
        op.create_table('projects',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), nullable=False),
                        sa.Column('name', sa.String(), nullable=True),
                        sa.Column('organisation_id', sa.Integer(), nullable=True),
                        sa.Column('description', sa.String(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'tool_configs' not in tables:
        op.create_table('tool_configs',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), nullable=False),
                        sa.Column('name', sa.String(), nullable=True),
                        sa.Column('key', sa.String(), nullable=True),
                        sa.Column('value', sa.String(), nullable=True),
                        sa.Column('agent_id', sa.Integer(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'tools' not in tables:
        op.create_table('tools',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                        sa.Column('name', sa.String(), nullable=True),
                        sa.Column('folder_name', sa.String(), nullable=True),
                        sa.Column('class_name', sa.String(), nullable=True),
                        sa.Column('file_name', sa.String(), nullable=True),
                        sa.PrimaryKeyConstraint('id')
                        )
    if 'users' not in tables:
        op.create_table('users',
                        sa.Column('created_at', sa.DateTime(), nullable=True),
                        sa.Column('updated_at', sa.DateTime(), nullable=True),
                        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                        sa.Column('name', sa.String(), nullable=True),
                        sa.Column('email', sa.String(), nullable=True),
                        sa.Column('password', sa.String(), nullable=True),
                        sa.Column('organisation_id', sa.Integer(), nullable=True),
                        sa.PrimaryKeyConstraint('id'),
                        sa.UniqueConstraint('email')
                        )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ten base tables of the initial schema, in the order listed below."""
    tables_to_drop = (
        'users',
        'tools',
        'tool_configs',
        'projects',
        'organisations',
        'budgets',
        'agents',
        'agent_executions',
        'agent_execution_feeds',
        'agent_configurations',
    )
    for table_name in tables_to_drop:
        op.drop_table(table_name)
================================================
FILE: migrations/versions/467e85d5e1cd_updated_resources_added_exec_id.py
================================================
"""updated_resources_added_exec_id
Revision ID: 467e85d5e1cd
Revises: ba60b12ae109
Create Date: 2023-07-10 08:54:46.702652
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '467e85d5e1cd'
down_revision = 'ba60b12ae109'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable integer column ``resources.agent_execution_id``."""
    execution_id_column = sa.Column('agent_execution_id', sa.Integer(), nullable=True)
    op.add_column('resources', execution_id_column)
def downgrade() -> None:
    """Drop ``resources.agent_execution_id``, reversing :func:`upgrade`."""
    table_name, column_name = 'resources', 'agent_execution_id'
    op.drop_column(table_name, column_name)
================================================
FILE: migrations/versions/516ecc1c723d_adding_marketplace_template_id_to_agent_.py
================================================
"""adding marketplace_template_id to agent templates
Revision ID: 516ecc1c723d
Revises: 8962bed0d809
Create Date: 2023-06-13 17:10:06.262764
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '516ecc1c723d'
down_revision = '8962bed0d809'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable integer column ``agent_templates.marketplace_template_id``."""
    template_ref_column = sa.Column('marketplace_template_id', sa.Integer(), nullable=True)
    op.add_column('agent_templates', template_ref_column)
def downgrade() -> None:
    """Drop ``agent_templates.marketplace_template_id``, reversing :func:`upgrade`.

    ``op.drop_column`` takes the column *name* as its second argument. The
    previous version passed a freshly built ``sa.Column`` object there, which
    Alembic cannot use as a column name, so the downgrade could not run.
    """
    op.drop_column('agent_templates', 'marketplace_template_id')
================================================
FILE: migrations/versions/5184645e9f12_add_question_to_agent_execution_.py
================================================
"""add question to agent execution permission
Revision ID: 5184645e9f12
Revises: 9419b3340af7
Create Date: 2023-07-21 08:16:14.702389
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '5184645e9f12'
down_revision = '9419b3340af7'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add a nullable ``question`` text column to ``agent_execution_permissions``."""
    question_column = sa.Column('question', sa.Text(), nullable=True)
    op.add_column('agent_execution_permissions', question_column)
def downgrade() -> None:
    """Drop ``agent_execution_permissions.question``, reversing :func:`upgrade`."""
    permissions_table = 'agent_execution_permissions'
    op.drop_column(permissions_table, 'question')
================================================
FILE: migrations/versions/520aa6776347_create_models_config.py
================================================
"""create models config
Revision ID: 520aa6776347
Revises: 446884dcae58
Create Date: 2023-08-01 07:48:13.724938
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '520aa6776347'
down_revision = '446884dcae58'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``models_config`` table.

    Columns: ``provider``, ``api_key`` and ``org_id`` (all required), plus the
    nullable ``created_at``/``updated_at`` timestamps; primary key is ``id``.
    """
    columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('provider', sa.String(), nullable=False),
        sa.Column('api_key', sa.String(), nullable=False),
        sa.Column('org_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('models_config', *columns, sa.PrimaryKeyConstraint('id'))
def downgrade() -> None:
    """Drop the ``models_config`` table created by :func:`upgrade`."""
    doomed_table = 'models_config'
    op.drop_table(doomed_table)
================================================
FILE: migrations/versions/598cfb37292a_adding_agent_templates.py
================================================
"""adding agent templates
Revision ID: 598cfb37292a
Revises: 2cc1179834b0
Create Date: 2023-06-05 12:44:30.982492
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine import Inspector
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '598cfb37292a'
down_revision = '2cc1179834b0'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the agent template tables and link agents/executions to them.

    Creates ``agent_template_steps`` and ``agent_templates``, adds
    ``agent_executions.current_step_id`` and ``agents.agent_template_id``, then
    indexes the new lookup columns.
    """
    step_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('agent_template_id', sa.Integer(), nullable=True),
        sa.Column('unique_id', sa.String(), nullable=True),
        sa.Column('prompt', sa.Text(), nullable=True),
        sa.Column('variables', sa.Text(), nullable=True),
        sa.Column('output_type', sa.String(), nullable=True),
        sa.Column('step_type', sa.String(), nullable=True),
        sa.Column('next_step_id', sa.Integer(), nullable=True),
        sa.Column('history_enabled', sa.Boolean(), nullable=True),
        sa.Column('completion_prompt', sa.Text(), nullable=True),
    ]
    op.create_table('agent_template_steps', *step_columns, sa.PrimaryKeyConstraint('id'))

    template_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
    ]
    op.create_table('agent_templates', *template_columns, sa.PrimaryKeyConstraint('id'))

    op.add_column('agent_executions', sa.Column('current_step_id', sa.Integer()))
    op.add_column('agents', sa.Column('agent_template_id', sa.Integer()))

    # (index name, table, indexed columns)
    index_specs = (
        ("ix_agents_agnt_template_id", "agents", ['agent_template_id']),
        ("ix_aea_step_id", "agent_executions", ['current_step_id']),
        ("ix_ats_unique_id", "agent_template_steps", ['unique_id']),
        ("ix_at_name", "agent_templates", ['name']),
    )
    for index_name, table_name, indexed_columns in index_specs:
        op.create_index(index_name, table_name, indexed_columns)
def downgrade() -> None:
    """Reverse :func:`upgrade`: unlink agents/executions and drop the template tables.

    The indexes that ``upgrade`` created are not dropped explicitly; on
    PostgreSQL they are removed together with the columns/tables they cover.

    The auto-generated version also re-added ``agent_executions.name`` here.
    ``upgrade`` never drops that column, so re-adding it reverses nothing and
    fails whenever the column is still present (it is added by an earlier
    migration, presumably a91808a89623 — verify against the revision chain).
    That stray step has been removed.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('agents', 'agent_template_id')
    op.drop_column('agent_executions', 'current_step_id')
    op.drop_table('agent_templates')
    op.drop_table('agent_template_steps')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/5d5f801f28e7_create_model_table.py
================================================
"""create model table
Revision ID: 5d5f801f28e7
Revises: be1d922bf2ad
Create Date: 2023-08-07 05:36:29.791610
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '5d5f801f28e7'
down_revision = 'be1d922bf2ad'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``models`` table.

    Every descriptive column (name, endpoint, provider id, token limit, type,
    version, org id, features) is required; ``description`` and the two
    timestamps are nullable; primary key is ``id``.
    """
    model_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('model_name', sa.String(), nullable=False),
        sa.Column('description', sa.String(), nullable=True),
        sa.Column('end_point', sa.String(), nullable=False),
        sa.Column('model_provider_id', sa.Integer(), nullable=False),
        sa.Column('token_limit', sa.Integer(), nullable=False),
        sa.Column('type', sa.String(), nullable=False),
        sa.Column('version', sa.String(), nullable=False),
        sa.Column('org_id', sa.Integer(), nullable=False),
        sa.Column('model_features', sa.String(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('models', *model_columns, sa.PrimaryKeyConstraint('id'))
def downgrade() -> None:
    """Drop the ``models`` table created by :func:`upgrade`."""
    doomed_table = 'models'
    op.drop_table(doomed_table)
================================================
FILE: migrations/versions/661ec8a4c32e_open_ai_error_handling.py
================================================
"""open_ai_error_handling
Revision ID: 661ec8a4c32e
Revises: c4f2f6ba602a
Create Date: 2023-09-07 10:41:07.462436
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '661ec8a4c32e'
down_revision = 'c4f2f6ba602a'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable columns ``agent_execution_feeds.error_message`` (string)
    and ``agent_executions.last_shown_error_id`` (integer)."""
    new_columns = (
        ('agent_execution_feeds', sa.Column('error_message', sa.String(), nullable=True)),
        ('agent_executions', sa.Column('last_shown_error_id', sa.Integer(), nullable=True)),
    )
    for table_name, column in new_columns:
        op.add_column(table_name, column)
def downgrade() -> None:
    """Drop the two columns added by :func:`upgrade`, in reverse order."""
    for table_name, column_name in (('agent_executions', 'last_shown_error_id'),
                                    ('agent_execution_feeds', 'error_message')):
        op.drop_column(table_name, column_name)
================================================
FILE: migrations/versions/71e3980d55f5_knowledge_and_vector_dbs.py
================================================
"""Knowledge and Vector dbs
Revision ID: 71e3980d55f5
Revises: cac478732572
Create Date: 2023-07-26 07:18:06.492832
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '71e3980d55f5'
down_revision = 'cac478732572'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the knowledge, marketplace-stats and vector-db tables.

    Each table has an autoincrementing integer ``id`` primary key and ends with
    nullable ``created_at``/``updated_at`` timestamps.
    """
    def timestamp_columns():
        # Build fresh Column objects for every table: a SQLAlchemy Column can be
        # attached to only one Table.
        return [sa.Column('created_at', sa.DateTime(), nullable=True),
                sa.Column('updated_at', sa.DateTime(), nullable=True)]

    op.create_table('knowledge_configs',
                    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                    sa.Column('knowledge_id', sa.Integer(), nullable=False),
                    sa.Column('key', sa.String(), nullable=True),
                    sa.Column('value', sa.Text(), nullable=True),
                    *timestamp_columns(),
                    sa.PrimaryKeyConstraint('id'))
    op.create_table('knowledges',
                    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                    sa.Column('name', sa.String(), nullable=False),
                    sa.Column('description', sa.String(), nullable=True),
                    sa.Column('vector_db_index_id', sa.Integer(), nullable=True),
                    sa.Column('organisation_id', sa.Integer(), nullable=True),
                    sa.Column('contributed_by', sa.String(), nullable=True),
                    *timestamp_columns(),
                    sa.PrimaryKeyConstraint('id'))
    op.create_table('marketplace_stats',
                    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                    sa.Column('reference_id', sa.Integer(), nullable=True),
                    sa.Column('reference_name', sa.String(), nullable=True),
                    sa.Column('key', sa.String(), nullable=True),
                    sa.Column('value', sa.Integer(), nullable=True),
                    *timestamp_columns(),
                    sa.PrimaryKeyConstraint('id'))
    op.create_table('vector_db_configs',
                    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                    sa.Column('vector_db_id', sa.Integer(), nullable=False),
                    sa.Column('key', sa.String(), nullable=True),
                    sa.Column('value', sa.Text(), nullable=True),
                    *timestamp_columns(),
                    sa.PrimaryKeyConstraint('id'))
    op.create_table('vector_db_indices',
                    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                    sa.Column('name', sa.String(), nullable=False),
                    sa.Column('vector_db_id', sa.Integer(), nullable=True),
                    sa.Column('dimensions', sa.Integer(), nullable=True),
                    sa.Column('state', sa.String(), nullable=True),
                    *timestamp_columns(),
                    sa.PrimaryKeyConstraint('id'))
    op.create_table('vector_dbs',
                    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
                    sa.Column('name', sa.String(), nullable=False),
                    sa.Column('db_type', sa.String(), nullable=True),
                    sa.Column('organisation_id', sa.Integer(), nullable=True),
                    *timestamp_columns(),
                    sa.PrimaryKeyConstraint('id'))
def downgrade() -> None:
    """Drop every table created by :func:`upgrade`, in reverse creation order.

    ``marketplace_stats`` is created by ``upgrade`` but was missing from this
    list. A downgrade therefore left it behind, and re-running ``upgrade``
    afterwards would fail because the table already existed.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('vector_dbs')
    op.drop_table('vector_db_indices')
    op.drop_table('vector_db_configs')
    op.drop_table('marketplace_stats')
    op.drop_table('knowledges')
    op.drop_table('knowledge_configs')
    # ### end Alembic commands ###
================================================
FILE: migrations/versions/7a3e336c0fba_added_tools_related_models.py
================================================
"""added_tools_related_models
Revision ID: 7a3e336c0fba
Revises: 1d54db311055
Create Date: 2023-06-18 11:05:35.801505
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '7a3e336c0fba'
down_revision = '1d54db311055'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Introduce ``toolkits`` and re-key tool configs and tools onto them.

    Creates ``toolkits``. ``tool_configs`` gains ``toolkit_id`` and loses its
    ``name`` and ``agent_id`` columns. ``tools`` gains ``description`` and
    ``toolkit_id``.
    """
    toolkit_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(), nullable=True),
        sa.Column('description', sa.String(), nullable=True),
        sa.Column('show_toolkit', sa.Boolean(), nullable=True),
        sa.Column('organisation_id', sa.Integer(), nullable=True),
        sa.Column('tool_code_link', sa.String(), nullable=True),
    ]
    op.create_table('toolkits', *toolkit_columns, sa.PrimaryKeyConstraint('id'))

    op.add_column('tool_configs', sa.Column('toolkit_id', sa.Integer(), nullable=True))
    for obsolete_column in ('name', 'agent_id'):
        op.drop_column('tool_configs', obsolete_column)

    for tools_column in (sa.Column('description', sa.String(), nullable=True),
                         sa.Column('toolkit_id', sa.Integer(), nullable=True)):
        op.add_column('tools', tools_column)
def downgrade() -> None:
    """Undo :func:`upgrade` step by step, in reverse order.

    The restored ``tool_configs`` columns (``agent_id``, ``name``) come back
    nullable, so any previous values are not recovered.
    """
    for tools_column in ('toolkit_id', 'description'):
        op.drop_column('tools', tools_column)
    restored_columns = (
        sa.Column('agent_id', sa.INTEGER(), autoincrement=False, nullable=True),
        sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True),
    )
    for restored in restored_columns:
        op.add_column('tool_configs', restored)
    op.drop_column('tool_configs', 'toolkit_id')
    op.drop_table('toolkits')
================================================
FILE: migrations/versions/83424de1347e_added_agent_execution_config.py
================================================
"""added_agent_execution_config
Revision ID: 83424de1347e
Revises: c02f3d759bf3
Create Date: 2023-07-03 22:42:50.091762
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '83424de1347e'
down_revision = 'c02f3d759bf3'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create ``agent_execution_configs``: key/value rows tied to an agent_execution_id."""
    config_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('agent_execution_id', sa.Integer(), nullable=True),
        sa.Column('key', sa.String(), nullable=True),
        sa.Column('value', sa.Text(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('agent_execution_configs', *config_columns, sa.PrimaryKeyConstraint('id'))
def downgrade() -> None:
    """Drop the ``agent_execution_configs`` table created by :func:`upgrade`."""
    doomed_table = 'agent_execution_configs'
    op.drop_table(doomed_table)
================================================
FILE: migrations/versions/8962bed0d809_creating_agent_templates.py
================================================
"""creating agent templates
Revision ID: 8962bed0d809
Revises: d9b3436197eb
Create Date: 2023-06-10 15:40:08.942612
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '8962bed0d809'
down_revision = 'd9b3436197eb'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create ``agent_templates`` and ``agent_template_configs`` plus their lookup indexes."""
    template_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('organisation_id', sa.Integer(), nullable=True),
        sa.Column('agent_workflow_id', sa.Integer(), nullable=True),
    ]
    op.create_table('agent_templates', *template_columns, sa.PrimaryKeyConstraint('id'))

    config_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('agent_template_id', sa.Integer(), nullable=True),
        sa.Column('key', sa.String(), nullable=True),
        sa.Column('value', sa.Text(), nullable=True),
    ]
    op.create_table('agent_template_configs', *config_columns, sa.PrimaryKeyConstraint('id'))

    # (index name, table, indexed columns)
    index_specs = (
        ("ix_atc_agnt_template_id_key", "agent_template_configs", ['agent_template_id', 'key']),
        ("ix_agt_agnt_organisation_id", "agent_templates", ['organisation_id']),
        ("ix_agt_agnt_workflow_id", "agent_templates", ['agent_workflow_id']),
        ("ix_agt_agnt_name", "agent_templates", ['name']),
    )
    for index_name, table_name, indexed_columns in index_specs:
        op.create_index(index_name, table_name, indexed_columns)
def downgrade() -> None:
    """Drop the two tables created by :func:`upgrade` (configs first)."""
    for table_name in ('agent_template_configs', 'agent_templates'):
        op.drop_table(table_name)
================================================
FILE: migrations/versions/9270eb5a8475_local_llms.py
================================================
"""local_llms
Revision ID: 9270eb5a8475
Revises: 3867bb00a495
Create Date: 2023-10-04 09:26:33.865424
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '9270eb5a8475'
down_revision = '3867bb00a495'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable integer column ``models.context_length``."""
    context_length_column = sa.Column('context_length', sa.Integer(), nullable=True)
    op.add_column('models', context_length_column)
def downgrade() -> None:
    """Drop ``models.context_length``, reversing :func:`upgrade`."""
    table_name, column_name = 'models', 'context_length'
    op.drop_column(table_name, column_name)
================================================
FILE: migrations/versions/9419b3340af7_create_agent_workflow.py
================================================
"""create agent workflow
Revision ID: 9419b3340af7
Revises: fe234ea6e9bc
Create Date: 2023-07-18 16:46:03.497943
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = '9419b3340af7'
down_revision = 'fe234ea6e9bc'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create ``agent_workflows``, ``agent_workflow_steps`` and ``agent_workflow_step_tools``.

    ``agent_workflow_steps.next_steps`` is a PostgreSQL JSONB column.
    """
    workflow_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('organisation_id', sa.Integer(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('agent_workflows', *workflow_columns, sa.PrimaryKeyConstraint('id'))

    step_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('step_type', sa.String(), nullable=False),
        sa.Column('agent_workflow_id', sa.Integer(), nullable=True),
        sa.Column('action_reference_id', sa.Integer(), nullable=True),
        sa.Column('action_type', sa.String(), nullable=True),
        sa.Column('unique_id', sa.String(), nullable=False),
        sa.Column('next_steps', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('agent_workflow_steps', *step_columns, sa.PrimaryKeyConstraint('id'))

    step_tool_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('unique_id', sa.String(), nullable=True),
        sa.Column('tool_name', sa.String(), nullable=True),
        sa.Column('input_instruction', sa.Text(), nullable=True),
        sa.Column('output_instruction', sa.Text(), nullable=True),
        sa.Column('history_enabled', sa.Boolean(), nullable=True),
        sa.Column('completion_prompt', sa.Text(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('agent_workflow_step_tools', *step_tool_columns, sa.PrimaryKeyConstraint('id'))
def downgrade() -> None:
    """Drop the three workflow tables created by :func:`upgrade`."""
    workflow_tables = ('agent_workflows', 'agent_workflow_steps', 'agent_workflow_step_tools')
    for table_name in workflow_tables:
        op.drop_table(table_name)
================================================
FILE: migrations/versions/a91808a89623_added_resources.py
================================================
"""added resources
Revision ID: a91808a89623
Revises: 44b0d6f2d1b3
Create Date: 2023-06-01 07:00:33.982485
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'a91808a89623'
down_revision = '44b0d6f2d1b3'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create ``resources`` and add ``extra_info``/``name`` string columns to the
    agent execution feed and execution tables."""
    resource_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(), nullable=True),
        sa.Column('storage_type', sa.String(), nullable=True),
        sa.Column('path', sa.String(), nullable=True),
        sa.Column('size', sa.Integer(), nullable=True),
        sa.Column('type', sa.String(), nullable=True),
        sa.Column('channel', sa.String(), nullable=True),
        sa.Column('project_id', sa.Integer(), nullable=True),
    ]
    op.create_table('resources', *resource_columns, sa.PrimaryKeyConstraint('id'))

    extra_columns = (
        ('agent_execution_feeds', 'extra_info'),
        ('agent_executions', 'name'),
    )
    for table_name, column_name in extra_columns:
        op.add_column(table_name, sa.Column(column_name, sa.String(), nullable=True))
def downgrade() -> None:
    """Reverse :func:`upgrade`: drop the added columns, then the ``resources`` table."""
    for table_name, column_name in (('agent_executions', 'name'),
                                    ('agent_execution_feeds', 'extra_info')):
        op.drop_column(table_name, column_name)
    op.drop_table('resources')
================================================
FILE: migrations/versions/ba60b12ae109_create_agent_scheduler.py
================================================
"""create_agent_scheduler
Revision ID: ba60b12ae109
Revises: 83424de1347e
Create Date: 2023-07-04 10:58:37.991063
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'ba60b12ae109'
down_revision = '83424de1347e'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create ``agent_schedule`` with non-unique indexes on expiry_date, status and agent_id."""
    schedule_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('agent_id', sa.Integer(), nullable=True),
        sa.Column('start_time', sa.DateTime(), nullable=True),
        sa.Column('next_scheduled_time', sa.DateTime(), nullable=True),
        sa.Column('recurrence_interval', sa.String(), nullable=True),
        sa.Column('expiry_date', sa.DateTime(), nullable=True),
        sa.Column('expiry_runs', sa.Integer(), nullable=True),
        sa.Column('current_runs', sa.Integer(), nullable=True),
        sa.Column('status', sa.String(), nullable=True),
    ]
    op.create_table('agent_schedule', *schedule_columns, sa.PrimaryKeyConstraint('id'))

    # op.f() marks each name as final so a configured naming convention is not
    # applied on top of it.
    indexed = (
        ('ix_agent_schedule_expiry_date', 'expiry_date'),
        ('ix_agent_schedule_status', 'status'),
        ('ix_agent_schedule_agent_id', 'agent_id'),
    )
    for index_name, column_name in indexed:
        op.create_index(op.f(index_name), 'agent_schedule', [column_name], unique=False)
def downgrade() -> None:
    """Drop the ``agent_schedule`` indexes (reverse creation order), then the table."""
    for index_name in ('ix_agent_schedule_agent_id',
                       'ix_agent_schedule_status',
                       'ix_agent_schedule_expiry_date'):
        op.drop_index(op.f(index_name), table_name='agent_schedule')
    op.drop_table('agent_schedule')
================================================
FILE: migrations/versions/be1d922bf2ad_create_call_logs_table.py
================================================
"""create call logs table
Revision ID: be1d922bf2ad
Revises: 520aa6776347
Create Date: 2023-08-08 08:42:37.148178
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'be1d922bf2ad'
down_revision = '520aa6776347'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``call_logs`` table; only ``model`` and the timestamps are nullable."""
    log_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('agent_execution_name', sa.String(), nullable=False),
        sa.Column('agent_id', sa.Integer(), nullable=False),
        sa.Column('tokens_consumed', sa.Integer(), nullable=False),
        sa.Column('tool_used', sa.String(), nullable=False),
        sa.Column('model', sa.String(), nullable=True),
        sa.Column('org_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('call_logs', *log_columns, sa.PrimaryKeyConstraint('id'))
def downgrade() -> None:
    """Drop the ``call_logs`` table created by :func:`upgrade`."""
    doomed_table = 'call_logs'
    op.drop_table(doomed_table)
================================================
FILE: migrations/versions/c02f3d759bf3_add_summary_to_resource.py
================================================
"""add summary to resource
Revision ID: c02f3d759bf3
Revises: c5c19944c90c
Create Date: 2023-06-27 05:07:29.016704
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'c02f3d759bf3'
down_revision = 'c5c19944c90c'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable text column ``resources.summary``."""
    summary_column = sa.Column('summary', sa.Text(), nullable=True)
    op.add_column('resources', summary_column)
def downgrade() -> None:
    """Drop ``resources.summary``, reversing :func:`upgrade`."""
    table_name, column_name = 'resources', 'summary'
    op.drop_column(table_name, column_name)
================================================
FILE: migrations/versions/c4f2f6ba602a_agent_workflow_wait_step.py
================================================
"""agent_workflow_wait_step
Revision ID: c4f2f6ba602a
Revises: 40affbf3022b
Create Date: 2023-09-04 05:34:10.195248
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'c4f2f6ba602a'
down_revision = '40affbf3022b'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create ``agent_workflow_step_waits`` (name, description, unique_id, delay,
    wait_begin_time, status, timestamps); primary key is ``id``."""
    wait_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(), nullable=True),
        sa.Column('description', sa.String(), nullable=True),
        sa.Column('unique_id', sa.String(), nullable=True),
        sa.Column('delay', sa.Integer(), nullable=True),
        sa.Column('wait_begin_time', sa.DateTime(), nullable=True),
        sa.Column('status', sa.String(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('agent_workflow_step_waits', *wait_columns, sa.PrimaryKeyConstraint('id'))
def downgrade() -> None:
    """Drop the ``agent_workflow_step_waits`` table created by :func:`upgrade`."""
    doomed_table = 'agent_workflow_step_waits'
    op.drop_table(doomed_table)
================================================
FILE: migrations/versions/c5c19944c90c_create_oauth_tokens.py
================================================
"""Create Oauth Tokens
Revision ID: c5c19944c90c
Revises: 7a3e336c0fba
Create Date: 2023-06-30 07:26:29.180784
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'c5c19944c90c'
down_revision = '7a3e336c0fba'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create ``oauth_tokens`` and drop five earlier secondary indexes.

    The dropped indexes are on ``agent_execution_permissions``,
    ``agent_template_configs`` and ``agent_templates``.
    """
    token_columns = [
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=True),
        sa.Column('organisation_id', sa.Integer(), nullable=True),
        sa.Column('toolkit_id', sa.Integer(), nullable=True),
        sa.Column('key', sa.String(), nullable=True),
        sa.Column('value', sa.Text(), nullable=True),
    ]
    op.create_table('oauth_tokens', *token_columns, sa.PrimaryKeyConstraint('id'))

    # (index name, owning table)
    obsolete_indexes = (
        ('ix_agent_execution_permissions_agent_execution_id', 'agent_execution_permissions'),
        ('ix_atc_agnt_template_id_key', 'agent_template_configs'),
        ('ix_agt_agnt_name', 'agent_templates'),
        ('ix_agt_agnt_organisation_id', 'agent_templates'),
        ('ix_agt_agnt_workflow_id', 'agent_templates'),
    )
    for index_name, owner_table in obsolete_indexes:
        op.drop_index(index_name, table_name=owner_table)
def downgrade() -> None:
    """Recreate the five indexes dropped by :func:`upgrade`, then drop ``oauth_tokens``."""
    # (index name, table, indexed columns), in reverse of the order upgrade() dropped them
    restored_indexes = (
        ('ix_agt_agnt_workflow_id', 'agent_templates', ['agent_workflow_id']),
        ('ix_agt_agnt_organisation_id', 'agent_templates', ['organisation_id']),
        ('ix_agt_agnt_name', 'agent_templates', ['name']),
        ('ix_atc_agnt_template_id_key', 'agent_template_configs', ['agent_template_id', 'key']),
        ('ix_agent_execution_permissions_agent_execution_id', 'agent_execution_permissions', ['agent_execution_id']),
    )
    for index_name, table_name, indexed_columns in restored_indexes:
        op.create_index(index_name, table_name, indexed_columns, unique=False)
    op.drop_table('oauth_tokens')
================================================
FILE: migrations/versions/cac478732572_delete_agent_feature.py
================================================
"""delete_agent_feature
Revision ID: cac478732572
Revises: e39295ec089c
Create Date: 2023-07-13 17:18:42.003412
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'cac478732572'
down_revision = 'e39295ec089c'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable boolean ``agents.is_deleted`` with a server-side default of false."""
    is_deleted_column = sa.Column('is_deleted', sa.Boolean(), nullable=True,
                                  server_default=sa.false())
    op.add_column('agents', is_deleted_column)
def downgrade() -> None:
    """Drop ``agents.is_deleted``, reversing :func:`upgrade`."""
    table_name, column_name = 'agents', 'is_deleted'
    op.drop_column(table_name, column_name)
================================================
FILE: migrations/versions/d8315244ea43_updated_tool_configs.py
================================================
"""updated_tool_configs
Revision ID: d8315244ea43
Revises: 71e3980d55f5
Create Date: 2023-08-01 11:11:32.725355
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'd8315244ea43'
down_revision = '71e3980d55f5'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add nullable ``key_type``, ``is_secret`` and ``is_required`` columns to ``tool_configs``."""
    new_columns = (
        sa.Column('key_type', sa.String(), nullable=True),
        sa.Column('is_secret', sa.Boolean(), nullable=True),
        sa.Column('is_required', sa.Boolean(), nullable=True),
    )
    for column in new_columns:
        op.add_column('tool_configs', column)
def downgrade() -> None:
    """Drop the three ``tool_configs`` columns added by :func:`upgrade`, in reverse order."""
    for column_name in ('is_required', 'is_secret', 'key_type'):
        op.drop_column('tool_configs', column_name)
================================================
FILE: migrations/versions/d9b3436197eb_renaming_templates.py
================================================
"""renaming templates
Revision ID: d9b3436197eb
Revises: 3356a2f89a33
Create Date: 2023-06-10 09:28:28.262705
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id and `down_revision` the id of the migration
# it is applied on top of; no branch labels or extra dependencies are declared.
revision = 'd9b3436197eb'
down_revision = '3356a2f89a33'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Rename the template tables to workflow tables and their FK columns to match.

    ``agent_templates`` becomes ``agent_workflows`` and ``agent_template_steps``
    becomes ``agent_workflow_steps``. In ``agent_workflow_steps`` and ``agents``
    the ``agent_template_id`` column is renamed to ``agent_workflow_id``.
    """
    table_renames = (
        ('agent_templates', 'agent_workflows'),
        ('agent_template_steps', 'agent_workflow_steps'),
    )
    for old_name, new_name in table_renames:
        op.rename_table(old_name, new_name)
    for table_name in ('agent_workflow_steps', 'agents'):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column('agent_template_id', new_column_name='agent_workflow_id')
def downgrade() -> None:
    """Reverse :func:`upgrade`: restore the template table names and FK column names.

    ``upgrade`` renamed ``agent_template_id`` to ``agent_workflow_id`` on the steps
    table, so the reverse rename must target that table under its restored name,
    ``agent_template_steps``. The previous version altered ``agent_templates``
    instead. That table (renamed back from ``agent_workflows``) was never given an
    ``agent_workflow_id`` column by ``upgrade``, so the downgrade failed and the
    steps table kept its workflow column name.
    """
    op.rename_table('agent_workflows', 'agent_templates')
    op.rename_table('agent_workflow_steps', 'agent_template_steps')
    with op.batch_alter_table('agent_template_steps') as bop:
        bop.alter_column('agent_workflow_id', new_column_name='agent_template_id')
    with op.batch_alter_table('agents') as bop:
        bop.alter_column('agent_workflow_id', new_column_name='agent_template_id')
================================================
FILE: migrations/versions/e39295ec089c_creating_events.py
================================================
"""creating events
Revision ID: e39295ec089c
Revises: 467e85d5e1cd
Create Date: 2023-06-30 12:23:12.269999
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
# `down_revision` (not the module docstring) is what Alembic uses to chain this migration after its parent.
revision = 'e39295ec089c'
down_revision = '467e85d5e1cd'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``events`` table plus indexes on agent_id, org_id and event_property."""
    op.create_table('events',
                    sa.Column('created_at', sa.DateTime(), nullable=True),
                    sa.Column('updated_at', sa.DateTime(), nullable=True),
                    sa.Column('id', sa.Integer(), nullable=False),
                    sa.Column('event_name', sa.String(), nullable=False),
                    sa.Column('event_value', sa.Integer(), nullable=False),
                    # JSONB comes from the postgresql dialect, so this migration assumes PostgreSQL.
                    sa.Column('event_property', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
                    sa.Column('agent_id', sa.Integer(), nullable=True),
                    sa.Column('org_id', sa.Integer(), nullable=True),
                    sa.PrimaryKeyConstraint('id')
                    )
    # op.f() marks each index name as final, so it is used verbatim instead of
    # being run through a MetaData naming convention.
    op.create_index(op.f('ix_events_agent_id'), 'events', ['agent_id'], unique=False)
    op.create_index(op.f('ix_events_org_id'), 'events', ['org_id'], unique=False)
    # NOTE(review): with no postgresql_using argument this is PostgreSQL's default
    # B-tree index. On JSONB it only serves whole-value comparisons and rejects rows
    # above the B-tree size limit; GIN is the usual choice for key/containment
    # lookups. It is left unchanged because editing an already-applied migration
    # would make existing databases diverge.
    op.create_index(op.f('ix_events_event_property'), 'events', ['event_property'], unique=False)
def downgrade() -> None:
    """Drop the events indexes (in reverse creation order) and then the events table."""
    index_names = ('ix_events_event_property', 'ix_events_org_id', 'ix_events_agent_id')
    for index_name in index_names:
        op.drop_index(op.f(index_name), table_name='events')
    op.drop_table('events')
================================================
FILE: migrations/versions/fe234ea6e9bc_modify_agent_workflow_tables.py
================================================
"""update agent workflow tables
Revision ID: fe234ea6e9bc
Revises: d8315244ea43
Create Date: 2023-07-18 16:46:29.305378
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's id; `down_revision` is the parent revision it is applied on top of.
revision = 'fe234ea6e9bc'
down_revision = 'd8315244ea43'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Turn agent workflows into iteration workflows and add the step / task-queue columns."""
    table_renames = (('agent_workflows', 'iteration_workflows'),
                     ('agent_workflow_steps', 'iteration_workflow_steps'))
    for old_table, new_table in table_renames:
        op.rename_table(old_table, new_table)

    column_renames = (('iteration_workflow_steps', 'agent_workflow_id', 'iteration_workflow_id'),
                      ('agent_executions', 'current_step_id', 'current_agent_step_id'))
    for table_name, old_column, new_column in column_renames:
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column(old_column, new_column_name=new_column)

    op.add_column('agent_executions', sa.Column('iteration_workflow_step_id', sa.Integer(), nullable=True))
    # server_default backfills existing rows with false.
    op.add_column('iteration_workflows',
                  sa.Column('has_task_queue', sa.Boolean(), nullable=True, server_default=sa.false()))
def downgrade() -> None:
    """Revert upgrade() by running its operations in exact reverse order.

    The previous version only renamed the tables back. That left
    ``iteration_workflow_id`` and ``current_agent_step_id`` in place, so the
    schema did not match the parent revision.
    """
    # Drop the added columns while the tables still have their upgraded names.
    op.drop_column('iteration_workflows', 'has_task_queue')
    op.drop_column('agent_executions', 'iteration_workflow_step_id')
    # Undo the column renames.
    with op.batch_alter_table('agent_executions') as bop:
        bop.alter_column('current_agent_step_id', new_column_name='current_step_id')
    with op.batch_alter_table('iteration_workflow_steps') as bop:
        bop.alter_column('iteration_workflow_id', new_column_name='agent_workflow_id')
    # Finally restore the original table names.
    op.rename_table('iteration_workflow_steps', 'agent_workflow_steps')
    op.rename_table('iteration_workflows', 'agent_workflows')
================================================
FILE: nginx/default.conf
================================================
# Reverse proxy for the GUI (upstream host `gui`, port 3000) and the backend API
# (upstream host `backend`, port 8001). The host names are presumably
# docker-compose service names; verify against the compose files.
server {
    listen 80;

    # Default route: everything not matched by a more specific location goes to the GUI.
    location / {
        proxy_pass http://gui:3000;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    }

    # API routes: /api/<path> is forwarded to the backend as /<path>.
    location /api {
        proxy_pass http://backend:8001;
        # Accept request bodies up to 50 MB on API routes.
        client_max_body_size 50M;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        # Strip the /api prefix. NOTE: a bare "/api" (no trailing slash) does not
        # match this regex and reaches the backend unchanged.
        rewrite ^/api/(.*) /$1 break;
    }

    # Dev-server hot-module-reload endpoint: needs HTTP/1.1 plus the Upgrade /
    # Connection headers so the WebSocket handshake can be proxied.
    location /_next/webpack-hmr {
        proxy_pass http://gui:3000;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        proxy_cache_bypass $http_upgrade;
    }
}
================================================
FILE: package.json
================================================
{
"dependencies": {
"axios": "^1.4.0",
"react-toastify": "^9.1.3"
}
}
================================================
FILE: requirements.txt
================================================
aiohttp==3.8.4
aiosignal==1.3.1
alembic==1.11.1
amqp==5.1.1
anyio==3.7.0
apiclient==1.0.4
appdirs==1.4.4
async-timeout==4.0.2
attrs==23.1.0
beautifulsoup4==4.12.2
billiard==3.6.4.0
boto3==1.26.146
botocore==1.29.146
bs4==0.0.1
celery==5.2.7
certifi==2023.5.7
cffi==1.15.1
charset-normalizer==3.1.0
click==8.1.3
click-didyoumean==0.3.0
click-plugins==1.1.1
click-repl==0.2.0
colorama==0.4.6
confluent-kafka==2.1.1
cryptography==41.0.1
cssselect==1.2.0
chromadb==0.3.26
dataclasses-json==0.5.7
defusedxml==0.7.1
docx2txt==0.8
dnspython==2.3.0
email-validator==2.0.0.post2
exceptiongroup==1.1.1
fake-useragent==1.1.3
fastapi==0.95.2
fastapi-jwt-auth==0.5.0
FastAPI-SQLAlchemy==0.2.1
feedfinder2==0.0.4
feedparser==6.0.10
filelock==3.12.0
frozenlist==1.3.3
google-search-results==2.4.2
google-serp-api==1.0.3
google-api-core==2.11.0
google-api-python-client==2.88.0
google-auth==2.19.1
google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0
greenlet==2.0.2
h11==0.14.0
halo==0.0.31
httpcore==0.17.2
httptools==0.5.0
httpx==0.24.1
idna==3.4
importlib-metadata==6.6.0
importlib-resources==5.12.0
itsdangerous==2.1.2
jieba3k==0.35.1
Jinja2==3.1.2
jira==3.5.0
jmespath==1.0.1
joblib==1.2.0
json5==0.9.14
jsonmerge==1.9.0
jsonschema==4.17.3
kombu==5.2.4
llama-index==0.6.35
log-symbols==0.0.14
loguru==0.7.0
lxml==4.9.2
Mako==1.2.4
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
multidict==6.0.4
mypy-extensions==1.0.0
newspaper3k==0.2.8
nltk==3.8.1
numexpr==2.8.4
numpy==1.24.3
oauthlib==3.2.2
oauth2client==4.1.3
openai==0.27.7
openapi-schema-pydantic==1.2.4
orjson==3.8.14
packaging==23.1
parse==1.19.0
Pillow==9.5.0
pinecone-client==2.2.1
prompt-toolkit==3.0.38
psycopg2==2.9.6
pycparser==2.21
pydantic==1.10.8
PyJWT==1.7.1
PyPDF2==3.0.1
pyquery==2.0.0
pyrsistent==0.19.3
pytest==7.3.2
python-dateutil==2.8.2
python-dotenv==1.0.0
python-multipart==0.0.6
pytz==2023.3
PyYAML==6.0
qdrant-client==1.3.1
redis==4.5.5
regex==2023.5.5
replicate==0.8.4
requests==2.31.0
requests-file==1.5.1
requests-html==0.10.0
requests-oauthlib==1.3.1
requests-toolbelt==1.0.0
s3transfer==0.6.1
safetensors==0.3.2
sgmllib3k==1.0.0
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
spinners==0.0.24
starlette==0.27.0
SQLAlchemy==2.0.16
tenacity==8.2.2
termcolor==2.3.0
tiktoken==0.4.0
tinysegmenter==0.3
tldextract==3.4.4
tqdm==4.65.0
tweepy==4.14.0
typing-inspect==0.8.0
ujson==5.7.0
urllib3==1.26.16
uvicorn==0.22.0
vine==5.0.0
w3lib==2.1.1
watchfiles==0.19.0
wcwidth==0.2.6
weaviate-client==3.20.1
websockets==10.4
yarl==1.9.2
zipp==3.15.0
slack-sdk==3.21.3
pylint==2.17.4
pre-commit==3.3.3
pytest-cov==4.1.0
pytest-mock==3.11.1
transformers==4.30.2
pypdf==3.11.0
python-pptx==0.6.21
EbookLib==0.18
html2text==2020.1.16
duckduckgo-search==3.8.3
google-generativeai==0.1.0
unstructured==0.8.1
ai21==1.2.6
typing-extensions==4.5.0
llama_cpp_python==0.2.7
================================================
FILE: run.bat
================================================
@echo off
rem Bootstrap SuperAGI on Windows: check config.yaml, ensure a virtualenv,
rem install requirements.txt and run test.py.

echo Checking if config.yaml file exists...
if not exist config.yaml (
    echo ERROR: config.yaml file not found. Please create the config.yaml file.
    exit /b 1
)

rem "python3" is frequently missing on Windows, so fall back to "python".
echo Checking if virtual environment is activated...
if not defined VIRTUAL_ENV (
    echo Virtual environment not activated. Creating and activating virtual environment...
    python3 -m venv venv >nul 2>&1 || python -m venv venv
    if errorlevel 1 (
        echo Error: Failed to create virtual environment.
        exit /b 1
    )
    call venv\Scripts\activate.bat
) else (
    echo Virtual environment is already activated.
)

rem "pip show" has no -r option, so the old "already installed?" probe always
rem failed. "pip install -r" skips requirements that are already satisfied.
rem stdout is hidden, but errors still reach the console.
echo Checking requirements...
echo Installing requirements...
pip install -r requirements.txt >nul
if errorlevel 1 (
    echo ERROR: Failed to install requirements from requirements.txt.
    exit /b 1
)

rem Choose the interpreter once, instead of re-running test.py under python3
rem whenever the first run exits with an error.
where python >nul 2>&1
if errorlevel 1 (
    echo Running test.py with python3...
    python3 test.py
) else (
    echo Running test.py with python...
    python test.py
)
exit /b %errorlevel%
================================================
FILE: run.sh
================================================
#!/bin/bash
# config.yaml is mandatory; stop immediately when it is absent.
[ -f "config.yaml" ] || {
  echo "ERROR: config.yaml file not found. Please create the config.yaml file."
  exit 1
}
# Fetch the text-generation-webui sources into tgwui/text-generation-webui once.
# The old test used -f (regular file) on what is a directory, so it never matched
# and the repository was cloned and mv'd again on every run. Cloning straight
# into the checked path keeps the layout deterministic.
if [ ! -d "tgwui/text-generation-webui" ]; then
  echo "Downloading tgwui src"
  if ! mkdir -p tgwui \
    || ! git clone https://github.com/oobabooga/text-generation-webui tgwui/text-generation-webui; then
    # Like before, a failed download does not stop the script.
    echo "WARNING: Failed to download text-generation-webui; continuing without it." >&2
  fi
fi
# Succeed (status 0) when a Python virtualenv is active, i.e. $VIRTUAL_ENV is non-empty.
is_venv_activated() {
  test -n "${VIRTUAL_ENV}"
}
# Create and activate ./venv unless a virtualenv is already active.
# The activate script is venv/bin/activate on both Linux and macOS, so there is
# no per-OS branch. A failure now stops the script instead of silently carrying on.
if is_venv_activated; then
  echo "Virtual environment is already activated."
else
  echo "Virtual environment not activated. Creating and activating virtual environment..."
  if ! python3 -m venv venv; then
    echo "ERROR: Failed to create virtual environment." >&2
    exit 1
  fi
  # shellcheck source=/dev/null
  if ! source venv/bin/activate || ! is_venv_activated; then
    echo "ERROR: Failed to activate virtual environment." >&2
    exit 1
  fi
fi
# `pip show` has no -r option, so the former "already installed?" probe always
# failed and the else-branch was unreachable. `pip install -r` already skips
# requirements that are satisfied. stdout stays quiet, but errors now surface
# and abort the script.
echo "Checking requirements..."
echo "Installing requirements..."
if ! pip install -r requirements.txt >/dev/null; then
  echo "ERROR: Failed to install requirements from requirements.txt." >&2
  exit 1
fi
echo "All requirements are installed."
# Launch the requested front end: `./run.sh ui` runs ui.py, `./run.sh cli` runs
# cli2.py. With no argument the script only prepares the environment.
#
# Pick the interpreter once. Inside the venv `python` exists; python3 is used
# only when no `python` is on PATH. The old code re-ran the whole program under
# python3 whenever the first run exited non-zero.
if command -v python >/dev/null 2>&1; then
  py=python
else
  py=python3
fi

case "${1:-}" in
  ui)
    echo "Running UI..."
    "$py" ui.py
    ;;
  cli)
    echo "Running superagi cli..."
    "$py" cli2.py
    ;;
  "")
    # Environment setup only.
    ;;
  *)
    echo "Usage: $0 [ui|cli]" >&2
    exit 2
    ;;
esac
================================================
FILE: run_gui.py
================================================
import os
import sys
import subprocess
from time import sleep
import shutil
from superagi.lib.logger import logger
def check_command(command, message):
    """Log ``message`` and exit with status 1 when ``command`` cannot be found on PATH."""
    if shutil.which(command) is None:
        logger.info(message)
        sys.exit(1)
def run_npm_commands():
    """Run ``npm install`` in the ``gui`` directory; log and exit with status 1 if it fails.

    ``cwd="gui"`` replaces ``os.chdir``, so this process's working directory is
    never left changed (for example when ``npm`` is missing and
    ``FileNotFoundError`` propagates).
    """
    try:
        subprocess.run(["npm", "install"], check=True, cwd="gui")
    except subprocess.CalledProcessError as exc:
        logger.error(f"Error during '{' '.join(exc.cmd)}'. Exiting.")
        sys.exit(1)
def run_server():
    """Start the API (uvicorn on 0.0.0.0:8000) and the GUI dev server in the background.

    Returns:
        tuple: ``(api_process, ui_process)``, both :class:`subprocess.Popen` handles.

    Raises:
        Whatever ``Popen`` raises for the GUI (e.g. ``FileNotFoundError``). The
        already-started API process is terminated first so it is not orphaned.
    """
    api_process = subprocess.Popen(["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"])
    try:
        # cwd= instead of os.chdir, so this process's working directory is untouched.
        ui_process = subprocess.Popen(["npm", "run", "dev"], cwd="gui")
    except Exception:
        api_process.terminate()
        raise
    return api_process, ui_process
def cleanup(api_process, ui_process, timeout=10):
    """Terminate the API and GUI processes, then exit with status 1.

    Each process gets ``terminate()`` and up to ``timeout`` seconds to exit; a
    process still running after that is killed. This makes "Processes
    terminated" true when it is logged. ``None`` handles are skipped.

    Args:
        api_process: ``Popen`` handle for the API server (or None).
        ui_process: ``Popen`` handle for the GUI server (or None).
        timeout (float): Seconds to wait for each process after ``terminate()``.
    """
    logger.info("Shutting down processes...")
    processes = [proc for proc in (api_process, ui_process) if proc is not None]
    for proc in processes:
        proc.terminate()
    for proc in processes:
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
    logger.info("Processes terminated. Exiting.")
    sys.exit(1)
if __name__ == "__main__":
    check_command("node", "Node.js is not installed. Please install it and try again.")
    check_command("npm", "npm is not installed. Please install npm to proceed.")
    check_command("uvicorn", "uvicorn is not installed. Please install uvicorn to proceed.")

    run_npm_commands()

    # Start the servers outside the supervision block. If this raises there are
    # no process handles to pass to cleanup(); the old code hit a NameError here.
    try:
        api_process, ui_process = run_server()
    except Exception as exc:
        logger.error(f"Failed to start servers: {exc}")
        sys.exit(1)

    try:
        # Poll periodically and stop supervising as soon as either server exits.
        while api_process.poll() is None and ui_process.poll() is None:
            sleep(30)
        logger.error("A server process exited; shutting down.")
    except KeyboardInterrupt:
        pass  # Normal Ctrl-C shutdown.
    except Exception as exc:
        logger.error(f"Unexpected error while running servers: {exc}")
    cleanup(api_process, ui_process)
================================================
FILE: run_gui.sh
================================================
#!/bin/bash
# PIDs of the background API (uvicorn) and GUI (npm run dev) processes.
# run_server sets them and cleanup kills them; they stay empty until started.
api_process=""
ui_process=""
# Exit the script with status 1 when the executable named by $1 is not on PATH.
check_command() {
  if ! command -v "$1" >/dev/null 2>&1; then
    echo "$1 is not installed. Please install $1 to proceed."
    exit 1
  fi
}
# Install dependencies and build the GUI inside ./gui, exiting the script with
# status 1 on any failure. A failed `cd gui` now aborts; previously npm would
# have run in the wrong directory.
function run_npm_commands() {
  if ! cd gui; then
    echo "Error: cannot enter the 'gui' directory. Exiting."
    exit 1
  fi
  if ! npm install; then
    echo "Error during 'npm install'. Exiting."
    exit 1
  fi
  if ! npm run build; then
    echo "Error during 'npm run build'. Exiting."
    exit 1
  fi
  cd ..
}
# Start the API (uvicorn, port 8000) and the GUI dev server in the background and
# record their PIDs in api_process / ui_process.
function run_server() {
  uvicorn main:app --host 0.0.0.0 --port 8000 &
  api_process=$!
  # `exec` makes the backgrounded subshell become npm itself, so $! is npm's PID
  # rather than a wrapping subshell's, and cleanup's kill reaches the dev server.
  (cd gui && exec npm run dev) &
  ui_process=$!
}
# Stop both background servers and exit with status 1.
# PIDs are quoted and empty ones (server never started) are skipped. kill errors
# are silenced because a server may already have exited on its own.
function cleanup() {
  echo "Shutting down processes..."
  local pid
  for pid in "$api_process" "$ui_process"; do
    [ -n "$pid" ] || continue
    kill "$pid" 2>/dev/null
  done
  echo "Processes terminated. Exiting."
  exit 1
}
# Tear both servers down on Ctrl-C (SIGINT) and on SIGTERM (e.g. `kill` or a
# supervisor stopping the script). Previously only SIGINT was trapped, so
# SIGTERM left the servers running.
trap cleanup SIGINT SIGTERM

check_command "node"
check_command "npm"
check_command "uvicorn"
run_npm_commands
run_server
# Block until the API server exits, or until a trapped signal runs cleanup.
wait "$api_process"
================================================
FILE: superagi/__init__.py
================================================
================================================
FILE: superagi/agent/__init__.py
================================================
================================================
FILE: superagi/agent/agent_iteration_step_handler.py
================================================
from datetime import datetime
import json
from sqlalchemy import asc
from sqlalchemy.sql.operators import and_
import logging
import superagi
from superagi.agent.agent_message_builder import AgentLlmMessageBuilder
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.agent.output_handler import ToolOutputHandler, get_output_handler
from superagi.agent.task_queue import TaskQueue
from superagi.agent.tool_builder import ToolBuilder
from superagi.apm.event_handler import EventHandler
from superagi.config.config import get_config
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
from superagi.resource_manager.resource_summary import ResourceSummarizer
from superagi.tools.resource.query_resource import QueryResourceTool
from superagi.tools.thinking.tools import ThinkingTool
from superagi.apm.call_log_helper import CallLogHelper
class AgentIterationStepHandler:
    """ Handles iteration workflow steps in the agent workflow."""

    def __init__(self, session, llm, agent_id: int, agent_execution_id: int, memory=None):
        """Create a handler bound to one agent execution.

        Args:
            session: Database session used for every query and commit.
            llm: LLM client; this class calls its ``get_model()`` and ``chat_completion()``.
            agent_id (int): Id of the agent being run.
            agent_execution_id (int): Id of the agent execution being advanced.
            memory: Optional memory object forwarded to tools and output handlers.
        """
        self.session = session
        self.llm = llm
        self.agent_execution_id = agent_execution_id
        self.agent_id = agent_id
        self.memory = memory
        self.organisation = Agent.find_org_by_agent_id(self.session, agent_id=self.agent_id)
        # Task queue keyed by this execution's id.
        self.task_queue = TaskQueue(str(self.agent_execution_id))

    def execute_step(self):
        """Run one iteration-workflow step for the agent execution.

        Steps:
            1. Build the prompt and the LLM messages, then call the LLM.
            2. Record token usage and a call log.
            3. Pass the reply to the step's output handler.
            4. Move the execution on according to the handler's status.

        Returns early, doing nothing else, while a tool-permission request is
        still pending.

        Raises:
            RuntimeError: If the LLM response has no ``content``.
        """
        agent_config = Agent.fetch_configuration(self.session, self.agent_id)
        execution = AgentExecution.get_agent_execution_from_id(self.session, self.agent_execution_id)
        iteration_workflow_step = IterationWorkflowStep.find_by_id(self.session, execution.iteration_workflow_step_id)
        agent_execution_config = AgentExecutionConfiguration.fetch_configuration(self.session, self.agent_execution_id)

        if not self._handle_wait_for_permission(execution, agent_config, agent_execution_config,
                                                iteration_workflow_step):
            return

        workflow_step = AgentWorkflowStep.find_by_id(self.session, execution.current_agent_step_id)
        # Already resolved in __init__; querying it again returned the same organisation.
        organisation = self.organisation
        iteration_workflow = IterationWorkflow.find_by_id(self.session, workflow_step.action_reference_id)
        model = self.llm.get_model()

        agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id)
        if not agent_feeds:
            # No feeds yet for this execution: start from an empty task queue.
            self.task_queue.clear_tasks()

        agent_tools = self._build_tools(agent_config, agent_execution_config)
        prompt = self._build_agent_prompt(iteration_workflow=iteration_workflow,
                                          agent_config=agent_config,
                                          agent_execution_config=agent_execution_config,
                                          prompt=iteration_workflow_step.prompt,
                                          agent_tools=agent_tools)

        messages = AgentLlmMessageBuilder(self.session, self.llm, model, self.agent_id, self.agent_execution_id) \
            .build_agent_messages(prompt, agent_feeds, history_enabled=iteration_workflow_step.history_enabled,
                                  completion_prompt=iteration_workflow_step.completion_prompt)

        logger.debug("Prompt messages:", messages)
        current_tokens = TokenCounter.count_message_tokens(messages=messages, model=model)
        # Whatever the prompt leaves of the model's token limit is the completion budget.
        completion_budget = TokenCounter(session=self.session,
                                         organisation_id=organisation.id).token_limit(model) - current_tokens
        response = self.llm.chat_completion(messages, completion_budget)

        if 'error' in response and response['message'] is not None:
            ErrorHandler.handle_openai_errors(self.session, self.agent_id, self.agent_execution_id, response['message'])

        if 'content' not in response or response['content'] is None:
            raise RuntimeError("Failed to get response from llm")

        # The reply is plain text, not a list of chat messages, so it is counted
        # with the text counter (count_message_tokens expects message dicts).
        total_tokens = current_tokens + TokenCounter.count_text_tokens(response['content'])
        AgentExecution.update_tokens(self.session, self.agent_execution_id, total_tokens)

        tool_name = self._extract_tool_name(response['content'])
        CallLogHelper(session=self.session, organisation_id=organisation.id).create_call_log(
            execution.name, agent_config['agent_id'], total_tokens, tool_name, agent_config['model'])

        assistant_reply = response['content']
        output_handler = get_output_handler(iteration_workflow_step.output_type,
                                            agent_execution_id=self.agent_execution_id,
                                            agent_config=agent_config, memory=self.memory, agent_tools=agent_tools)
        response = output_handler.handle(self.session, assistant_reply)

        if response.status == "COMPLETE":
            execution.status = "COMPLETED"
            self.session.commit()
            self._update_agent_execution_next_step(execution, iteration_workflow_step.next_step_id, "COMPLETE")
            EventHandler(session=self.session).create_event('run_completed',
                                                            {'agent_execution_id': execution.id,
                                                             'name': execution.name,
                                                             'tokens_consumed': execution.num_of_tokens,
                                                             "calls": execution.num_of_calls},
                                                            execution.agent_id, organisation.id)
        elif response.status == "WAITING_FOR_PERMISSION":
            execution.status = "WAITING_FOR_PERMISSION"
            execution.permission_id = response.permission_id
            self.session.commit()
        else:
            # moving to next step of iteration or workflow
            self._update_agent_execution_next_step(execution, iteration_workflow_step.next_step_id)
            logger.info(f"Starting next job for agent execution id: {self.agent_execution_id}")

        self.session.flush()

    @staticmethod
    def _extract_tool_name(assistant_reply: str) -> str:
        """Return ``tool.name`` from a JSON assistant reply, or '' when it is missing or unparsable.

        Replies that parse as JSON but are not objects (or whose ``tool`` is not an
        object) yield '' instead of raising AttributeError.
        """
        try:
            content = json.loads(assistant_reply)
        except json.JSONDecodeError:
            logger.error("Decoding JSON has failed")
            return ''
        tool = content.get('tool', {}) if isinstance(content, dict) else {}
        return tool.get('name', '') if isinstance(tool, dict) else ''

    def _update_agent_execution_next_step(self, execution, next_step_id, step_response: str = "default"):
        """Advance ``execution`` and commit.

        A ``next_step_id`` of -1 is treated as "no further iteration step". The
        next agent-workflow step for ``step_response`` is then looked up, and the
        execution is marked COMPLETED when that lookup yields "COMPLETE". Any other
        id becomes the execution's next iteration-workflow step.
        """
        if next_step_id == -1:
            next_step = AgentWorkflowStep.fetch_next_step(self.session, execution.current_agent_step_id, step_response)
            if str(next_step) == "COMPLETE":
                execution.current_agent_step_id = -1
                execution.status = "COMPLETED"
            else:
                AgentExecution.assign_next_step_id(self.session, self.agent_execution_id, next_step.id)
        else:
            execution.iteration_workflow_step_id = next_step_id
        self.session.commit()

    def _build_agent_prompt(self, iteration_workflow: IterationWorkflow, agent_config: dict,
                            agent_execution_config: dict,
                            prompt: str, agent_tools: list):
        """Fill the step's prompt template with goals, instructions, constraints and tools.

        When the iteration workflow has a task queue, the current / last task
        variables are filled in as well, within the model's token limit minus
        MAX_TOOL_TOKEN_LIMIT.
        """
        # NOTE(review): AgentLlmMessageBuilder uses a fallback of 800 for this same
        # config key; confirm which default is intended.
        max_token_limit = int(get_config("MAX_TOOL_TOKEN_LIMIT", 600))
        prompt = AgentPromptBuilder.replace_main_variables(prompt, agent_execution_config["goal"],
                                                           agent_execution_config["instruction"],
                                                           agent_config["constraints"], agent_tools,
                                                           (not iteration_workflow.has_task_queue))
        if iteration_workflow.has_task_queue:
            response = self.task_queue.get_last_task_details()
            last_task, last_task_result = (response["task"], response["response"]) if response is not None else ("", "")
            current_task = self.task_queue.get_first_task() or ""
            # Use the running model's limit, like every other token_limit() call here.
            token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id) \
                .token_limit(self.llm.get_model()) - max_token_limit
            prompt = AgentPromptBuilder.replace_task_based_variables(prompt, current_task, last_task, last_task_result,
                                                                     self.task_queue.get_tasks(),
                                                                     self.task_queue.get_completed_tasks(), token_limit)
        return prompt

    def _build_tools(self, agent_config: dict, agent_execution_config: dict):
        """Build the tool instances available to this step.

        The list always starts with ThinkingTool. QueryResourceTool is added when a
        resource summary exists, followed by the execution's configured tools. Each
        tool is then given its default parameters.
        """
        agent_tools = [ThinkingTool()]
        config_data = AgentConfiguration.get_model_api_key(self.session, self.agent_id, agent_config["model"])
        model_api_key = config_data['api_key']
        tool_builder = ToolBuilder(self.session, self.agent_id, self.agent_execution_id)
        resource_summary = ResourceSummarizer(session=self.session, agent_id=self.agent_id,
                                              model=agent_config['model']) \
            .fetch_or_create_agent_resource_summary(default_summary=agent_config.get("resource_summary"))
        if resource_summary is not None:
            agent_tools.append(QueryResourceTool())
        # `Tool.file_name is not None` was a Python identity test on the Column object
        # (always True), so it never filtered anything; isnot(None) emits IS NOT NULL.
        user_tools = self.session.query(Tool).filter(
            and_(Tool.id.in_(agent_execution_config["tools"]), Tool.file_name.isnot(None))).all()
        for tool in user_tools:
            agent_tools.append(tool_builder.build_tool(tool))

        agent_tools = [tool_builder.set_default_params_tool(tool, agent_config, agent_execution_config,
                                                            model_api_key, resource_summary, self.memory)
                       for tool in agent_tools]
        return agent_tools

    def _handle_wait_for_permission(self, agent_execution, agent_config: dict, agent_execution_config: dict,
                                    iteration_workflow_step: IterationWorkflowStep):
        """
        Handles the wait for permission when the agent execution is waiting for permission.

        If the permission was approved, the stored tool call is executed. If it was
        denied, a denial message (with any user feedback) is used as the result.
        Either way the assistant reply and the result are added as feeds, the
        execution is set back to RUNNING, and it moves to the next step.

        Args:
            agent_execution (AgentExecution): The agent execution.
            agent_config (dict): The agent configuration.
            agent_execution_config (dict): The agent execution configuration.
            iteration_workflow_step (IterationWorkflowStep): The iteration workflow step.

        Returns:
            bool: False while the permission request is still PENDING; True otherwise.
        """
        if agent_execution.status != "WAITING_FOR_PERMISSION":
            return True
        agent_execution_permission = self.session.query(AgentExecutionPermission).filter(
            AgentExecutionPermission.id == agent_execution.permission_id).first()
        if agent_execution_permission.status == "PENDING":
            logger.error("handle_wait_for_permission: Permission is still pending")
            return False
        if agent_execution_permission.status == "APPROVED":
            agent_tools = self._build_tools(agent_config, agent_execution_config)
            tool_output_handler = ToolOutputHandler(self.agent_execution_id, agent_config, agent_tools, self.memory)
            tool_result = tool_output_handler.handle_tool_response(self.session,
                                                                   agent_execution_permission.assistant_reply)
            result = tool_result.result
        else:
            result = f"User denied the permission to run the tool {agent_execution_permission.tool_name}" \
                     f"{' and has given the following feedback : ' + agent_execution_permission.user_feedback if agent_execution_permission.user_feedback else ''}"

        agent_execution_feed = AgentExecutionFeed(agent_execution_id=agent_execution_permission.agent_execution_id,
                                                  agent_id=agent_execution_permission.agent_id,
                                                  feed=agent_execution_permission.assistant_reply,
                                                  role="assistant",
                                                  feed_group_id=agent_execution.current_feed_group_id)
        self.session.add(agent_execution_feed)
        agent_execution_feed1 = AgentExecutionFeed(agent_execution_id=agent_execution_permission.agent_execution_id,
                                                   agent_id=agent_execution_permission.agent_id,
                                                   feed=result, role="user",
                                                   feed_group_id=agent_execution.current_feed_group_id)
        self.session.add(agent_execution_feed1)
        agent_execution.status = "RUNNING"
        execution = AgentExecution.find_by_id(self.session, agent_execution_permission.agent_execution_id)
        self._update_agent_execution_next_step(execution, iteration_workflow_step.next_step_id)
        self.session.commit()
        return True
================================================
FILE: superagi/agent/agent_message_builder.py
================================================
import time
from typing import Tuple, List
from sqlalchemy import asc
from superagi.config.config import get_config
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.types.common import BaseMessage
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent import Agent
class AgentLlmMessageBuilder:
    """Agent message builder for LLM agent."""

    def __init__(self, session, llm, llm_model: str, agent_id: int, agent_execution_id: int):
        """
        Args:
            session: Database session used for queries and commits.
            llm: LLM client; its ``chat_completion`` produces the long-term-memory summary.
            llm_model (str): Model name used for token counting and token limits.
            agent_id (int): Id of the agent.
            agent_execution_id (int): Id of the agent execution whose feeds are used.
        """
        self.session = session
        self.llm = llm
        self.llm_model = llm_model
        self.agent_id = agent_id
        self.agent_execution_id = agent_execution_id
        self.organisation = Agent.find_org_by_agent_id(self.session, self.agent_id)

    def build_agent_messages(self, prompt: str, agent_feeds: list, history_enabled=False,
                             completion_prompt: str = None):
        """ Build agent messages for LLM agent.

        Args:
            prompt (str): The prompt to be used for generating the agent messages.
            agent_feeds (list): The list of agent feeds.
            history_enabled (bool): Whether to use history or not.
            completion_prompt (str): The completion prompt to be used for generating the agent messages.

        Returns:
            list: ``{"role", "content"}`` dicts to send to the LLM. When
            ``agent_feeds`` is empty, these messages are also stored as the
            execution's initial feeds.
        """
        token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit(self.llm_model)
        max_output_token_limit = int(get_config("MAX_TOOL_TOKEN_LIMIT", 800))
        messages = [{"role": "system", "content": prompt}]

        if history_enabled:
            messages.append({"role": "system", "content": f"The current time and date is {time.strftime('%c')}"})
            base_token_limit = TokenCounter.count_message_tokens(messages, self.llm_model)
            full_message_history = [{'role': agent_feed.role, 'content': agent_feed.feed, 'chat_id': agent_feed.id}
                                    for agent_feed in agent_feeds]
            # The budget left after the system messages and the reserved output is
            # split: 3/4 for recent history kept verbatim, 1/4 for the LTM summary.
            history_budget = token_limit - base_token_limit - max_output_token_limit
            past_messages, current_messages = self._split_history(full_message_history, (history_budget // 4) * 3)
            if past_messages:
                ltm_summary = self._build_ltm_summary(past_messages=past_messages,
                                                      output_token_limit=history_budget // 4)
                messages.append({"role": "assistant", "content": ltm_summary})

            for history in current_messages:
                messages.append({"role": history["role"], "content": history["content"]})
            messages.append({"role": "user", "content": completion_prompt})

        # insert initial agent feeds
        self._add_initial_feeds(agent_feeds, messages)
        return messages

    def _split_history(self, history: List, pending_token_limit: int) -> Tuple[List[BaseMessage], List[BaseMessage]]:
        """Split ``history`` into ``(older, recent)`` so ``recent`` fits in ``pending_token_limit`` tokens.

        Walks newest-first, accumulating token counts. The message that overflows
        the budget, and everything before it, become ``older``, and the chat id of
        the newest ``older`` message is stored as ``last_agent_feed_ltm_summary_id``.
        Returns ``([], history)`` when everything fits.
        """
        hist_token_count = 0
        i = len(history)
        for message in reversed(history):
            token_count = TokenCounter.count_message_tokens([{"role": message["role"], "content": message["content"]}],
                                                            self.llm_model)
            hist_token_count += token_count
            if hist_token_count > pending_token_limit:
                self._add_or_update_last_agent_feed_ltm_summary_id(str(history[i - 1]['chat_id']))
                return history[:i], history[i:]
            i -= 1
        return [], history

    def _add_initial_feeds(self, agent_feeds: list, messages: list):
        """Store ``messages`` as DEFAULT-group feeds when the execution has no feeds yet."""
        if agent_feeds:
            return
        for message in messages:
            agent_execution_feed = AgentExecutionFeed(agent_execution_id=self.agent_execution_id,
                                                      agent_id=self.agent_id,
                                                      feed=message["content"],
                                                      role=message["role"],
                                                      feed_group_id="DEFAULT")
            self.session.add(agent_execution_feed)
        self.session.commit()

    def _add_or_update_last_agent_feed_ltm_summary_id(self, last_agent_feed_ltm_summary_id):
        """Persist the id of the last feed already covered by the LTM summary."""
        execution = AgentExecution(id=self.agent_execution_id)
        agent_execution_configs = {"last_agent_feed_ltm_summary_id": last_agent_feed_ltm_summary_id}
        AgentExecutionConfiguration.add_or_update_agent_execution_config(self.session, execution,
                                                                         agent_execution_configs)

    def _build_ltm_summary(self, past_messages, output_token_limit) -> str:
        """Summarise ``past_messages`` with the LLM, store the result as ``ltm_summary`` and return it.

        If the plain summary prompt plus its output would exceed the model's token
        limit, a recursive prompt is used instead. That prompt combines the previous
        stored summary with the feeds newer than ``last_agent_feed_ltm_summary_id``.

        Raises:
            RuntimeError: If the LLM returns no summary content.
        """
        ltm_prompt = self._build_prompt_for_ltm_summary(past_messages=past_messages,
                                                        token_limit=output_token_limit)

        summary = AgentExecutionConfiguration.fetch_value(self.session, self.agent_execution_id, "ltm_summary")
        previous_ltm_summary = summary.value if summary is not None else ""

        ltm_summary_base_token_limit = 10
        if ((TokenCounter.count_text_tokens(ltm_prompt) + ltm_summary_base_token_limit + output_token_limit)
                - TokenCounter(session=self.session, organisation_id=self.organisation.id).token_limit(self.llm_model)) > 0:
            last_agent_feed_ltm_summary_id = AgentExecutionConfiguration.fetch_value(self.session,
                                                                                     self.agent_execution_id,
                                                                                     "last_agent_feed_ltm_summary_id")
            last_agent_feed_ltm_summary_id = (
                int(last_agent_feed_ltm_summary_id.value)
                if last_agent_feed_ltm_summary_id is not None and last_agent_feed_ltm_summary_id.value is not None
                else 0
            )

            past_messages = self.session.query(AgentExecutionFeed.role, AgentExecutionFeed.feed,
                                               AgentExecutionFeed.id) \
                .filter(AgentExecutionFeed.agent_execution_id == self.agent_execution_id,
                        AgentExecutionFeed.id > last_agent_feed_ltm_summary_id) \
                .order_by(asc(AgentExecutionFeed.created_at)) \
                .all()

            past_messages = [
                {'role': past_message.role, 'content': past_message.feed, 'chat_id': past_message.id}
                for past_message in past_messages]

            ltm_prompt = self._build_prompt_for_recursive_ltm_summary_using_previous_ltm_summary(
                previous_ltm_summary=previous_ltm_summary, past_messages=past_messages, token_limit=output_token_limit)

        msgs = [{"role": "system", "content": "You are GPT Prompt writer"},
                {"role": "assistant", "content": ltm_prompt}]
        ltm_summary = self.llm.chat_completion(msgs)

        if 'error' in ltm_summary and ltm_summary['message'] is not None:
            ErrorHandler.handle_openai_errors(self.session, self.agent_id, self.agent_execution_id, ltm_summary['message'])

        # Same guard as the iteration step handler uses for its LLM reply. Without
        # it, an error response either fails with a bare KeyError below or stores
        # None as the summary.
        if 'content' not in ltm_summary or ltm_summary['content'] is None:
            raise RuntimeError("Failed to get LTM summary from llm")

        execution = AgentExecution(id=self.agent_execution_id)
        agent_execution_configs = {"ltm_summary": ltm_summary["content"]}
        AgentExecutionConfiguration.add_or_update_agent_execution_config(session=self.session, execution=execution,
                                                                         agent_execution_configs=agent_execution_configs)

        return ltm_summary["content"]

    @staticmethod
    def _format_past_messages(past_messages: List[BaseMessage]) -> str:
        """Render messages as ``"<role>: <content>"`` lines for the summary prompts."""
        return "".join(past_message["role"] + ": " + past_message["content"] + "\n"
                       for past_message in past_messages)

    def _build_prompt_for_ltm_summary(self, past_messages: List[BaseMessage], token_limit: int):
        """Fill agent_summary.txt with the past messages and a character budget of ``token_limit * 4``."""
        ltm_summary_prompt = PromptReader.read_agent_prompt(__file__, "agent_summary.txt")
        ltm_summary_prompt = ltm_summary_prompt.replace("{past_messages}", self._format_past_messages(past_messages))
        ltm_summary_prompt = ltm_summary_prompt.replace("{char_limit}", str(token_limit * 4))
        return ltm_summary_prompt

    def _build_prompt_for_recursive_ltm_summary_using_previous_ltm_summary(self, previous_ltm_summary: str,
                                                                           past_messages: List[BaseMessage],
                                                                           token_limit: int):
        """Fill agent_recursive_summary.txt with the previous summary, the past messages and a ``token_limit * 4`` char budget."""
        ltm_summary_prompt = PromptReader.read_agent_prompt(__file__, "agent_recursive_summary.txt")
        ltm_summary_prompt = ltm_summary_prompt.replace("{previous_ltm_summary}", previous_ltm_summary)
        ltm_summary_prompt = ltm_summary_prompt.replace("{past_messages}", self._format_past_messages(past_messages))
        ltm_summary_prompt = ltm_summary_prompt.replace("{char_limit}", str(token_limit * 4))
        return ltm_summary_prompt
================================================
FILE: superagi/agent/agent_prompt_builder.py
================================================
import json
import re
from pydantic.types import List
from superagi.helper.token_counter import TokenCounter
from superagi.tools.base_tool import BaseTool
FINISH_NAME = "finish"
class AgentPromptBuilder:
"""Agent prompt builder for LLM agent."""
@staticmethod
def add_list_items_to_string(items: List[str]) -> str:
list_string = ""
for i, item in enumerate(items):
list_string += f"{i + 1}. {item}\n"
return list_string
@classmethod
def add_tools_to_prompt(cls, tools: List[BaseTool], add_finish: bool = True) -> str:
    """Add tools to the prompt.

    Each tool becomes a numbered line. When ``add_finish`` is true, a final
    numbered entry for the built-in finish tool is appended.

    Args:
        tools (List[BaseTool]): The list of tools.
        add_finish (bool): Whether to add finish tool or not.

    Returns:
        str: The numbered tool list. It ends with a blank line after the finish
        entry, or with one extra newline when the finish entry is omitted.
    """
    final_string = "".join(f"{i + 1}. {cls._generate_tool_string(item)}\n" for i, item in enumerate(tools))
    if not add_finish:
        return final_string + "\n"

    finish_description = "use this to signal that you have finished all your objectives"
    finish_args = (
        '"response": "final response to let '
        'people know you have finished your objectives"'
    )
    finish_string = (
        f"{len(tools) + 1}. \"{FINISH_NAME}\": "
        f"{finish_description}, args: {finish_args}"
    )
    return final_string + finish_string + "\n\n"
@classmethod
def _generate_tool_string(cls, tool: BaseTool) -> str:
output = f"\"{tool.name}\": {tool.description}"
# print(tool.args)
output += f", args json schema: {json.dumps(tool.args)}"
return output
@classmethod
def clean_prompt(cls, prompt):
prompt = re.sub('[ \t]+', ' ', prompt)
return prompt.strip()
@classmethod
def replace_main_variables(cls, super_agi_prompt: str, goals: List[str], instructions: List[str], constraints: List[str],
tools: List[BaseTool], add_finish_tool: bool = True):
"""Replace the main variables in the super agi prompt.
Args:
super_agi_prompt (str): The super agi prompt.
goals (List[str]): The list of goals.
instructions (List[str]): The list of instructions.
constraints (List[str]): The list of constraints.
tools (List[BaseTool]): The list of tools.
add_finish_tool (bool): Whether to add finish tool or not.
"""
super_agi_prompt = super_agi_prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(goals))
if len(instructions) > 0 and len(instructions[0]) > 0:
task_str = "INSTRUCTION(Follow these instruction to decide the flow of execution and decide the next steps for achieving the task):"
super_agi_prompt = super_agi_prompt.replace("{instructions}", "INSTRUCTION: " + '\n' + AgentPromptBuilder.add_list_items_to_string(instructions))
super_agi_prompt = super_agi_prompt.replace("{task_instructions}", task_str + '\n' + AgentPromptBuilder.add_list_items_to_string(instructions))
else:
super_agi_prompt = super_agi_prompt.replace("{instructions}", '')
super_agi_prompt = super_agi_prompt.replace("{task_instructions}", "")
super_agi_prompt = super_agi_prompt.replace("{constraints}",
AgentPromptBuilder.add_list_items_to_string(constraints))
# logger.info(tools)
tools_string = AgentPromptBuilder.add_tools_to_prompt(tools, add_finish_tool)
super_agi_prompt = super_agi_prompt.replace("{tools}", tools_string)
return super_agi_prompt
@classmethod
def replace_task_based_variables(cls, super_agi_prompt: str, current_task: str, last_task: str,
last_task_result: str, pending_tasks: List[str], completed_tasks: list, token_limit: int):
"""Replace the task based variables in the super agi prompt.
Args:
super_agi_prompt (str): The super agi prompt.
current_task (str): The current task.
last_task (str): The last task.
last_task_result (str): The last task result.
pending_tasks (List[str]): The list of pending tasks.
completed_tasks (list): The list of completed tasks.
token_limit (int): The token limit.
"""
if "{current_task}" in super_agi_prompt:
super_agi_prompt = super_agi_prompt.replace("{current_task}", current_task)
if "{last_task}" in super_agi_prompt:
super_agi_prompt = super_agi_prompt.replace("{last_task}", last_task)
if "{last_task_result}" in super_agi_prompt:
super_agi_prompt = super_agi_prompt.replace("{last_task_result}", last_task_result)
if "{pending_tasks}" in super_agi_prompt:
super_agi_prompt = super_agi_prompt.replace("{pending_tasks}", str(pending_tasks))
completed_tasks.reverse()
if "{completed_tasks}" in super_agi_prompt:
completed_tasks_arr = []
for task in completed_tasks:
completed_tasks_arr.append(task['task'])
super_agi_prompt = super_agi_prompt.replace("{completed_tasks}", str(completed_tasks_arr))
base_token_limit = TokenCounter.count_message_tokens([{"role": "user", "content": super_agi_prompt}])
pending_tokens = token_limit - base_token_limit
final_output = ""
if "{task_history}" in super_agi_prompt:
for task in reversed(completed_tasks[-10:]):
final_output = f"Task: {task['task']}\nResult: {task['response']}\n" + final_output
token_count = TokenCounter.count_message_tokens([{"role": "user", "content": final_output}])
# giving buffer of 100 tokens
if token_count > min(600, pending_tokens):
break
super_agi_prompt = super_agi_prompt.replace("{task_history}", "\n" + final_output + "\n")
return super_agi_prompt
================================================
FILE: superagi/agent/agent_prompt_template.py
================================================
import re
from pydantic.types import List
from superagi.helper.prompt_reader import PromptReader
FINISH_NAME = "finish"


class AgentPromptTemplate:
    """Loads agent prompt templates and reports the placeholder variables each one expects."""

    @staticmethod
    def add_list_items_to_string(items: List[str]) -> str:
        """Return ``items`` as a 1-based numbered list, one newline-terminated line per item."""
        return "".join(f"{number}. {entry}\n" for number, entry in enumerate(items, start=1))

    @classmethod
    def clean_prompt(cls, prompt):
        """Collapse runs of spaces/tabs to one space and trim leading/trailing whitespace."""
        collapsed = re.sub('[ \t]+', ' ', prompt)
        return collapsed.strip()

    @classmethod
    def get_super_agi_single_prompt(cls):
        """Return the raw "superagi.txt" template (not cleaned) and its variables."""
        template = PromptReader.read_agent_prompt(__file__, "superagi.txt")
        variables = ["goals", "instructions", "constraints", "tools"]
        return {"prompt": template, "variables": variables}

    @classmethod
    def start_task_based(cls):
        """Return the cleaned "initialize_tasks.txt" template and its variables."""
        template = PromptReader.read_agent_prompt(__file__, "initialize_tasks.txt")
        cleaned = AgentPromptTemplate.clean_prompt(template)
        return {"prompt": cleaned, "variables": ["goals", "instructions"]}

    @classmethod
    def analyse_task(cls):
        """Return the cleaned "analyse_task.txt" template with {constraints} already filled in."""
        template = PromptReader.read_agent_prompt(__file__, "analyse_task.txt")
        constraint_text = AgentPromptTemplate.add_list_items_to_string(
            ['Exclusively use the tools listed in double quotes e.g. "tool name"'])
        prompt = AgentPromptTemplate.clean_prompt(template).replace("{constraints}", constraint_text)
        return {"prompt": prompt, "variables": ["goals", "instructions", "tools", "current_task"]}

    @classmethod
    def create_tasks(cls):
        """Return the cleaned "create_tasks.txt" template and its variables."""
        template = PromptReader.read_agent_prompt(__file__, "create_tasks.txt")
        variables = ["goals", "instructions", "last_task", "last_task_result", "pending_tasks"]
        return {"prompt": AgentPromptTemplate.clean_prompt(template), "variables": variables}

    @classmethod
    def prioritize_tasks(cls):
        """Return the cleaned "prioritize_tasks.txt" template and its variables."""
        template = PromptReader.read_agent_prompt(__file__, "prioritize_tasks.txt")
        variables = ["goals", "instructions", "last_task", "last_task_result", "pending_tasks"]
        return {"prompt": AgentPromptTemplate.clean_prompt(template), "variables": variables}
================================================
FILE: superagi/agent/agent_tool_step_handler.py
================================================
import json
from superagi.agent.task_queue import TaskQueue
from superagi.agent.agent_message_builder import AgentLlmMessageBuilder
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.agent.output_handler import ToolOutputHandler
from superagi.agent.output_parser import AgentSchemaToolOutputParser
from superagi.agent.queue_step_handler import QueueStepHandler
from superagi.agent.tool_builder import ToolBuilder
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.agent_workflow_step_tool import AgentWorkflowStepTool
from superagi.resource_manager.resource_summary import ResourceSummarizer
from superagi.tools.base_tool import BaseTool
from sqlalchemy import and_
class AgentToolStepHandler:
    """Handles the tools steps in the agent workflow"""

    def __init__(self, session, llm, agent_id: int, agent_execution_id: int, memory=None):
        """
        Args:
            session: The database session used for all reads and writes.
            llm: The LLM client; this class calls its get_model() and chat_completion().
            agent_id (int): The agent id.
            agent_execution_id (int): The agent execution id.
            memory: Optional memory store passed on to the built tool and its output handler.
        """
        self.session = session
        self.llm = llm
        self.agent_execution_id = agent_execution_id
        self.agent_id = agent_id
        self.memory = memory
        self.task_queue = TaskQueue(str(self.agent_execution_id))
        # Looked up once; reused for token limits and tool lookup below.
        self.organisation = Agent.find_org_by_agent_id(self.session, self.agent_id)

    def execute_step(self):
        """Execute the workflow step the agent execution currently points at.

        "TASK_QUEUE" steps are delegated to QueueStepHandler and "WAIT_FOR_PERMISSION" steps
        create a pending permission request. Any other step asks the LLM for the tool call,
        runs the tool, and chooses the next step from the optional output instruction.
        """
        execution = AgentExecution.get_agent_execution_from_id(self.session, self.agent_execution_id)
        workflow_step = AgentWorkflowStep.find_by_id(self.session, execution.current_agent_step_id)
        step_tool = AgentWorkflowStepTool.find_by_id(self.session, workflow_step.action_reference_id)
        agent_config = Agent.fetch_configuration(self.session, self.agent_id)
        agent_execution_config = AgentExecutionConfiguration.fetch_configuration(self.session, self.agent_execution_id)

        # False while a permission is pending, or right after resolving one (the next step
        # has then already been assigned).
        if not self._handle_wait_for_permission(execution, workflow_step):
            return

        if step_tool.tool_name == "TASK_QUEUE":
            step_response = QueueStepHandler(self.session, self.llm, self.agent_id, self.agent_execution_id).execute_step()
            next_step = AgentWorkflowStep.fetch_next_step(self.session, workflow_step.id, step_response)
            self._handle_next_step(next_step)
            return

        if step_tool.tool_name == "WAIT_FOR_PERMISSION":
            self._create_permission_request(execution, step_tool)
            return

        assistant_reply = self._process_input_instruction(agent_config, agent_execution_config, step_tool,
                                                          workflow_step)
        # NOTE(review): _process_input_instruction builds this same tool object internally but
        # does not return it, so it is built a second time here.
        tool_obj = self._build_tool_obj(agent_config, agent_execution_config, step_tool.tool_name)
        tool_output_handler = ToolOutputHandler(self.agent_execution_id, agent_config, [tool_obj], self.memory,
                                                output_parser=AgentSchemaToolOutputParser())
        final_response = tool_output_handler.handle(self.session, assistant_reply)
        step_response = "default"
        if step_tool.output_instruction:
            step_response = self._process_output_instruction(final_response.result, step_tool, workflow_step)

        next_step = AgentWorkflowStep.fetch_next_step(self.session, workflow_step.id, step_response)
        self._handle_next_step(next_step)
        self.session.flush()

    def _create_permission_request(self, execution, step_tool: AgentWorkflowStepTool):
        """Create a PENDING permission whose question is the step's input instruction, and park
        the execution in WAITING_FOR_PERMISSION pointing at it."""
        new_agent_execution_permission = AgentExecutionPermission(
            agent_execution_id=self.agent_execution_id,
            status="PENDING",
            agent_id=self.agent_id,
            tool_name="WAIT_FOR_PERMISSION",
            question=step_tool.input_instruction,
            assistant_reply="")
        self.session.add(new_agent_execution_permission)
        self.session.commit()
        self.session.flush()
        execution.permission_id = new_agent_execution_permission.id
        execution.status = "WAITING_FOR_PERMISSION"
        self.session.commit()

    def _handle_next_step(self, next_step):
        """Mark the execution COMPLETED when next_step is "COMPLETE"; otherwise point it at next_step."""
        if str(next_step) == "COMPLETE":
            agent_execution = AgentExecution.get_agent_execution_from_id(self.session, self.agent_execution_id)
            agent_execution.current_agent_step_id = -1
            agent_execution.status = "COMPLETED"
        else:
            AgentExecution.assign_next_step_id(self.session, self.agent_execution_id, next_step.id)
        self.session.commit()

    def _process_input_instruction(self, agent_config, agent_execution_config, step_tool, workflow_step):
        """Ask the LLM for the tool call (name and args) of this step.

        Returns:
            str: The raw content of the LLM reply.

        Raises:
            RuntimeError: If the LLM response carries no content.
        """
        tool_obj = self._build_tool_obj(agent_config, agent_execution_config, step_tool.tool_name)
        prompt = self._build_tool_input_prompt(step_tool, tool_obj, agent_execution_config)
        logger.info("Prompt: ", prompt)
        agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id)
        messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id,
                                          self.agent_execution_id) \
            .build_agent_messages(prompt, agent_feeds, history_enabled=step_tool.history_enabled,
                                  completion_prompt=step_tool.completion_prompt)
        current_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
        token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id) \
            .token_limit(self.llm.get_model())
        response = self.llm.chat_completion(messages, token_limit - current_tokens)
        # .get(): an error payload without a "message" key must not raise KeyError here.
        if 'error' in response and response.get('message') is not None:
            ErrorHandler.handle_openai_errors(self.session, self.agent_id, self.agent_execution_id, response['message'])
        if 'content' not in response or response['content'] is None:
            raise RuntimeError("Failed to get response from llm")
        # NOTE(review): the raw response dict (not a message list) is passed to the token
        # counter; confirm this counts the reply content as intended.
        total_tokens = current_tokens + TokenCounter.count_message_tokens(response, self.llm.get_model())
        AgentExecution.update_tokens(self.session, self.agent_execution_id, total_tokens)
        return response['content']

    def _build_tool_obj(self, agent_config, agent_execution_config, tool_name: str):
        """Build the named organisation tool and apply the agent's default parameters to it.

        Raises:
            RuntimeError: If no tool with ``tool_name`` exists in the organisation's toolkits.
        """
        model_api_key = AgentConfiguration.get_model_api_key(self.session, self.agent_id, agent_config["model"])['api_key']
        tool_builder = ToolBuilder(self.session, self.agent_id, self.agent_execution_id)
        resource_summary = ""
        if tool_name == "QueryResourceTool":
            resource_summary = ResourceSummarizer(session=self.session,
                                                  agent_id=self.agent_id,
                                                  model=agent_config["model"]).fetch_or_create_agent_resource_summary(
                default_summary=agent_config.get("resource_summary"))

        # Reuse the organisation cached in __init__ rather than querying it again.
        tool = self.session.query(Tool).join(Toolkit, and_(Tool.toolkit_id == Toolkit.id,
                                                           Toolkit.organisation_id == self.organisation.id,
                                                           Tool.name == tool_name)).first()
        if tool is None:
            raise RuntimeError(f"Tool '{tool_name}' not found for organisation {self.organisation.id}")
        tool_obj = tool_builder.build_tool(tool)
        tool_obj = tool_builder.set_default_params_tool(tool_obj, agent_config, agent_execution_config, model_api_key,
                                                        resource_summary, self.memory)
        return tool_obj

    def _process_output_instruction(self, final_response: str, step_tool: AgentWorkflowStepTool,
                                    workflow_step: AgentWorkflowStep):
        """Ask the LLM which outgoing branch the tool output selects.

        Returns:
            str: The LLM reply with single and double quotes stripped.

        Raises:
            RuntimeError: If the LLM response carries no content.
        """
        prompt = self._build_tool_output_prompt(step_tool, final_response, workflow_step)
        messages = [{"role": "system", "content": prompt}]
        current_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
        token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id) \
            .token_limit(self.llm.get_model())
        response = self.llm.chat_completion(messages, token_limit - current_tokens)
        if 'error' in response and response.get('message') is not None:
            ErrorHandler.handle_openai_errors(self.session, self.agent_id, self.agent_execution_id, response['message'])
        if 'content' not in response or response['content'] is None:
            raise RuntimeError("ToolWorkflowStepHandler: Failed to get output response from llm")
        total_tokens = current_tokens + TokenCounter.count_message_tokens(response, self.llm.get_model())
        AgentExecution.update_tokens(self.session, self.agent_execution_id, total_tokens)
        step_response = response['content']
        # Quotes are stripped so the reply can match a branch name exactly.
        return step_response.replace("'", "").replace("\"", "")

    def _build_tool_input_prompt(self, step_tool: AgentWorkflowStepTool, tool: BaseTool, agent_execution_config: dict):
        """Fill "agent_tool_input.txt" with the goals, tool name, input instruction and tool schema."""
        super_agi_prompt = PromptReader.read_agent_prompt(__file__, "agent_tool_input.txt")
        super_agi_prompt = super_agi_prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(
            agent_execution_config["goal"]))
        super_agi_prompt = super_agi_prompt.replace("{tool_name}", step_tool.tool_name)
        super_agi_prompt = super_agi_prompt.replace("{instruction}", step_tool.input_instruction)
        tool_schema = f"\"{tool.name}\": {tool.description}, args json schema: {json.dumps(tool.args)}"
        super_agi_prompt = super_agi_prompt.replace("{tool_schema}", tool_schema)
        return super_agi_prompt

    def _get_step_responses(self, workflow_step: AgentWorkflowStep):
        """Return the "step_response" value of every entry in workflow_step.next_steps."""
        return [step["step_response"] for step in workflow_step.next_steps]

    def _build_tool_output_prompt(self, step_tool: AgentWorkflowStepTool, tool_output: str,
                                  workflow_step: AgentWorkflowStep):
        """Fill "agent_tool_output.txt"; {output_options} lists the step's branches except "default"."""
        super_agi_prompt = PromptReader.read_agent_prompt(__file__, "agent_tool_output.txt")
        super_agi_prompt = super_agi_prompt.replace("{tool_output}", tool_output)
        super_agi_prompt = super_agi_prompt.replace("{tool_name}", step_tool.tool_name)
        super_agi_prompt = super_agi_prompt.replace("{instruction}", step_tool.output_instruction)
        step_responses = self._get_step_responses(workflow_step)
        if "default" in step_responses:
            step_responses.remove("default")
        super_agi_prompt = super_agi_prompt.replace("{output_options}", str(step_responses))
        return super_agi_prompt

    def _handle_wait_for_permission(self, agent_execution, workflow_step: AgentWorkflowStep):
        """
        Handles the wait for permission when the agent execution is waiting for permission.

        Args:
            agent_execution (AgentExecution): The agent execution.
            workflow_step (AgentWorkflowStep): The workflow step.

        Returns:
            bool: True when the execution is not waiting for permission and the step may run.
            False when the permission is still pending, or when it has just been resolved. In
            that case the "YES"/"NO" branch is already assigned as the next step, and a user
            feed entry holding the feedback (possibly empty) has been added.
        """
        if agent_execution.status != "WAITING_FOR_PERMISSION":
            return True
        agent_execution_permission = self.session.query(AgentExecutionPermission).filter(
            AgentExecutionPermission.id == agent_execution.permission_id).first()
        if agent_execution_permission.status == "PENDING":
            logger.error("handle_wait_for_permission: Permission is still pending")
            return False

        branch = "YES" if agent_execution_permission.status == "APPROVED" else "NO"
        next_step = AgentWorkflowStep.fetch_next_step(self.session, workflow_step.id, branch)
        user_feedback = agent_execution_permission.user_feedback
        result = (' User has given the following feedback : ' + user_feedback) if user_feedback else ''

        agent_execution_feed = AgentExecutionFeed(agent_execution_id=agent_execution_permission.agent_execution_id,
                                                  agent_id=agent_execution_permission.agent_id,
                                                  feed=result, role="user",
                                                  feed_group_id=agent_execution.current_feed_group_id)
        self.session.add(agent_execution_feed)
        agent_execution.status = "RUNNING"
        agent_execution.permission_id = -1
        self.session.commit()
        self._handle_next_step(next_step)
        self.session.commit()
        return False
================================================
FILE: superagi/agent/agent_workflow_step_wait_handler.py
================================================
from datetime import datetime
from superagi.agent.types.agent_execution_status import AgentExecutionStatus
from superagi.lib.logger import logger
from superagi.models.agent_execution import AgentExecution
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.agent_workflow_step_wait import AgentWorkflowStepWait
from superagi.agent.types.wait_step_status import AgentWorkflowStepWaitStatus
class AgentWaitStepHandler:
    """Handle Agent Wait Step in the agent workflow."""

    def __init__(self, session, agent_id, agent_execution_id):
        """
        Args:
            session: The database session.
            agent_id: The agent id.
            agent_execution_id: The agent execution id whose current step is a wait step.
        """
        self.session = session
        self.agent_id = agent_id
        self.agent_execution_id = agent_execution_id

    def execute_step(self):
        """Execute the agent wait step.

        When the step has a wait record, it is stamped with the current local time
        (datetime.now()) and marked WAITING. The execution is switched to WAIT_STEP and the
        changes are committed. Nothing is changed when no wait record exists.
        """
        logger.info("Executing Wait Step")
        execution = AgentExecution.get_agent_execution_from_id(self.session, self.agent_execution_id)
        workflow_step = AgentWorkflowStep.find_by_id(self.session, execution.current_agent_step_id)
        step_wait = AgentWorkflowStepWait.find_by_id(self.session, workflow_step.action_reference_id)

        if step_wait is not None:
            step_wait.wait_begin_time = datetime.now()
            step_wait.status = AgentWorkflowStepWaitStatus.WAITING.value
            execution.status = AgentExecutionStatus.WAIT_STEP.value
            self.session.commit()

    def handle_next_step(self):
        """Handle next step of agent workflow in case of wait step.

        Follows the "default" branch out of the current step. It either marks the execution
        COMPLETED (step id -1) or points it at the next step, then commits.
        """
        execution = AgentExecution.get_agent_execution_from_id(self.session, self.agent_execution_id)
        workflow_step = AgentWorkflowStep.find_by_id(self.session, execution.current_agent_step_id)
        step_response = "default"
        next_step = AgentWorkflowStep.fetch_next_step(self.session, workflow_step.id, step_response)

        if str(next_step) == "COMPLETE":
            # Reuse the execution loaded above instead of querying it a second time.
            execution.current_agent_step_id = -1
            execution.status = "COMPLETED"
        else:
            AgentExecution.assign_next_step_id(self.session, self.agent_execution_id, next_step.id)
        self.session.commit()
================================================
FILE: superagi/agent/common_types.py
================================================
from pydantic import BaseModel
class ToolExecutorResponse(BaseModel):
    """Result of running a tool, or of the restricted-mode permission check that precedes it."""

    # Execution state string; values assigned elsewhere in this codebase include
    # "PENDING", "COMPLETE" and "WAITING_FOR_PERMISSION".
    status: str
    # Tool output text; None when no tool output was produced.
    result: str = None
    # Whether the agent should retry this step.
    retry: bool = False
    # True when the tool call was parked waiting for a user's permission.
    is_permission_required: bool = False
    # Id of the permission request created for this call, if any.
    permission_id: int = None
class TaskExecutorResponse(BaseModel):
    """Result of handling a task-list output from the LLM."""

    # Queue state after handling; "COMPLETE" or "PENDING" where set in this codebase.
    status: str
    # Whether the step should be retried; both fields are required.
    retry: bool
================================================
FILE: superagi/agent/output_handler.py
================================================
import json
from superagi.agent.common_types import TaskExecutorResponse, ToolExecutorResponse
from superagi.agent.output_parser import AgentSchemaOutputParser
from superagi.agent.task_queue import TaskQueue
from superagi.agent.tool_executor import ToolExecutor
from superagi.helper.json_cleaner import JsonCleaner
from superagi.lib.logger import logger
from langchain.text_splitter import TokenTextSplitter
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.vector_store.base import VectorStore
import numpy as np
from superagi.models.agent_execution_permission import AgentExecutionPermission
class ToolOutputHandler:
    """Handles the tool output response from the thinking step"""

    def __init__(self,
                 agent_execution_id: int,
                 agent_config: dict,
                 tools: list,
                 memory: VectorStore = None,
                 output_parser=AgentSchemaOutputParser()):
        """
        Args:
            agent_execution_id (int): The agent execution id; also names the task queue.
            agent_config (dict): Agent configuration; "agent_id" and "permission_type" are read.
            tools (list): Tool objects available to this step; matched by their ``name``.
            memory (VectorStore): Optional store that receives thought/result text chunks.
            output_parser: Parser that turns the assistant reply into an action.
        """
        self.agent_execution_id = agent_execution_id
        self.task_queue = TaskQueue(str(agent_execution_id))
        self.agent_config = agent_config
        self.tools = tools
        self.output_parser = output_parser
        self.memory = memory

    def handle(self, session, assistant_reply):
        """Run the tool the assistant asked for and record the exchange as feeds.

        In restricted mode, a tool that requires permission is not run; the
        permission-request response is returned instead.

        Args:
            session (Session): The database session.
            assistant_reply (str): The assistant reply.

        Returns:
            ToolExecutorResponse: The tool (or permission) response.
        """
        permission_check = self._check_permission_in_restricted_mode(session, assistant_reply)
        if permission_check.is_permission_required:
            return permission_check

        tool_response = self.handle_tool_response(session, assistant_reply)
        agent_execution = AgentExecution.find_by_id(session, self.agent_execution_id)
        feed_group_id = agent_execution.current_feed_group_id
        # Assistant reply first, then the tool result, both in the current feed group.
        for feed_text, feed_role in ((assistant_reply, "assistant"), (tool_response.result, "system")):
            session.add(AgentExecutionFeed(agent_execution_id=self.agent_execution_id,
                                           agent_id=self.agent_config["agent_id"],
                                           feed=feed_text,
                                           role=feed_role,
                                           feed_group_id=feed_group_id))
        session.commit()

        if not tool_response.retry:
            tool_response = self._check_for_completion(tool_response)
        self.add_text_to_memory(assistant_reply, tool_response.result)
        return tool_response

    def add_text_to_memory(self, assistant_reply, tool_response_result):
        """
        Store the assistant's thought text plus the tool result in memory, split into chunks.

        Best effort: any failure (e.g. a reply that is not JSON or has no thoughts.text)
        is logged and swallowed. Does nothing when no memory is configured.

        Args:
            assistant_reply (str): The assistant reply.
            tool_response_result (str): The tool response.

        Returns:
            None
        """
        if self.memory is None:
            return
        try:
            reply = json.loads(assistant_reply)
            combined_text = reply['thoughts']['text'] + tool_response_result
            chunks = TokenTextSplitter(chunk_size=1024, chunk_overlap=10).split_text(combined_text)
            metadata = {"agent_execution_id": self.agent_execution_id}
            # Every chunk shares the same metadata dict.
            self.memory.add_texts(chunks, [metadata] * len(chunks))
        except Exception as exception:
            logger.error(f"Exception: {exception}")

    def handle_tool_response(self, session, assistant_reply):
        """Only handle processing of tool response"""
        action = self.output_parser.parse(assistant_reply)
        agent = session.query(Agent).filter(Agent.id == self.agent_config["agent_id"]).first()
        organisation = agent.get_agent_organisation(session)
        executor = ToolExecutor(organisation_id=organisation.id, agent_id=agent.id, tools=self.tools,
                                agent_execution_id=self.agent_execution_id)
        return executor.execute(session, action.name, action.args)

    def _check_permission_in_restricted_mode(self, session, assistant_reply: str):
        """In RESTRICTED mode, create a PENDING permission for tools flagged permission_required.

        Returns:
            ToolExecutorResponse: WAITING_FOR_PERMISSION with the new permission id, or a
            PENDING response that does not require permission.
        """
        action = self.output_parser.parse(assistant_reply)
        tools_by_name = {tool.name: tool for tool in self.tools}
        excluded_tools = [ToolExecutor.FINISH, '', None]
        needs_permission = (self.agent_config["permission_type"].upper() == "RESTRICTED"
                            and action.name not in excluded_tools
                            and tools_by_name.get(action.name)
                            and tools_by_name[action.name].permission_required)
        if not needs_permission:
            return ToolExecutorResponse(status="PENDING", is_permission_required=False)

        permission = AgentExecutionPermission(
            agent_execution_id=self.agent_execution_id,
            status="PENDING",
            agent_id=self.agent_config["agent_id"],
            tool_name=action.name,
            assistant_reply=assistant_reply)
        session.add(permission)
        session.commit()
        return ToolExecutorResponse(is_permission_required=True, status="WAITING_FOR_PERMISSION",
                                    permission_id=permission.id)

    def _check_for_completion(self, tool_response):
        """Mark the current task done; COMPLETE only when tasks were completed and none remain."""
        self.task_queue.complete_task(tool_response.result)
        remaining_tasks = self.task_queue.get_tasks()
        if self.task_queue.get_completed_tasks() and len(remaining_tasks) == 0:
            tool_response.status = "COMPLETE"
        # Outstanding tasks override a COMPLETE status coming from the tool itself.
        if remaining_tasks and tool_response.status == "COMPLETE":
            tool_response.status = "PENDING"
        return tool_response
class TaskOutputHandler:
    """Handles the task output from the LLM. Output is mostly in the array of tasks and
    handler adds every task to the task queue.
    """

    def __init__(self, agent_execution_id: int, agent_config: dict):
        """
        Args:
            agent_execution_id (int): The agent execution id; also names the task queue.
            agent_config (dict): Agent configuration; "agent_id" is read.
        """
        self.agent_execution_id = agent_execution_id
        self.task_queue = TaskQueue(str(agent_execution_id))
        self.agent_config = agent_config

    def handle(self, session, assistant_reply):
        """Parse the task array from the reply, queue the tasks and add one feed per task.

        Args:
            session (Session): The database session.
            assistant_reply (str): The LLM reply containing an array of tasks.

        Returns:
            TaskExecutorResponse: "COMPLETE" when the queue is empty afterwards, else "PENDING".

        Raises:
            ValueError / SyntaxError: If the extracted array is not a Python/JSON-style literal.
        """
        import ast  # local import keeps this fix self-contained

        assistant_reply = JsonCleaner.extract_json_array_section(assistant_reply)
        # The reply is LLM output, which can be steered by prompt injection. literal_eval
        # only accepts literals, unlike eval(), which would execute arbitrary code.
        tasks = ast.literal_eval(assistant_reply)
        # Flatten nested task lists into a single list.
        tasks = np.array(tasks).flatten().tolist()
        for task in reversed(tasks):
            self.task_queue.add_task(task)
        if len(tasks) > 0:
            logger.info("Adding task to queue: " + str(tasks))

        agent_execution = AgentExecution.find_by_id(session, self.agent_execution_id)
        for task in tasks:
            agent_execution_feed = AgentExecutionFeed(agent_execution_id=self.agent_execution_id,
                                                      agent_id=self.agent_config["agent_id"],
                                                      # f-string: non-str tasks no longer raise TypeError
                                                      feed=f"New Task Added: {task}",
                                                      role="system",
                                                      feed_group_id=agent_execution.current_feed_group_id)
            session.add(agent_execution_feed)
        status = "COMPLETE" if len(self.task_queue.get_tasks()) == 0 else "PENDING"
        session.commit()
        return TaskExecutorResponse(status=status, retry=False)
class ReplaceTaskOutputHandler:
    """Handles the replace/prioritize task output type.
    Output is mostly in the array of tasks and handler adds every task to the task queue.
    """

    def __init__(self, agent_execution_id: int, agent_config: dict):
        """
        Args:
            agent_execution_id (int): The agent execution id; also names the task queue.
            agent_config (dict): Agent configuration.
        """
        self.agent_execution_id = agent_execution_id
        self.task_queue = TaskQueue(str(agent_execution_id))
        self.agent_config = agent_config

    def handle(self, session, assistant_reply):
        """Replace the whole task queue with the task array from the reply.

        Args:
            session (Session): The database session.
            assistant_reply (str): The LLM reply containing the reprioritized task array.

        Returns:
            TaskExecutorResponse: "COMPLETE" when the queue is empty afterwards, else "PENDING".

        Raises:
            ValueError / SyntaxError: If the reply is not a list/tuple literal. The existing
                queue is left untouched in that case.
        """
        import ast  # local import keeps this fix self-contained

        assistant_reply = JsonCleaner.extract_json_array_section(assistant_reply)
        # literal_eval instead of eval(): LLM output must never be executed as code.
        tasks = ast.literal_eval(assistant_reply)
        # Validate before clear_tasks() so a malformed reply cannot wipe the queue.
        if not isinstance(tasks, (list, tuple)):
            raise ValueError(f"ReplaceTaskOutputHandler: expected a list of tasks, got {type(tasks).__name__}")
        self.task_queue.clear_tasks()
        for task in reversed(tasks):
            self.task_queue.add_task(task)
        if len(tasks) > 0:
            logger.info("Tasks reprioritized in order: " + str(tasks))
        status = "COMPLETE" if len(self.task_queue.get_tasks()) == 0 else "PENDING"
        session.commit()
        return TaskExecutorResponse(status=status, retry=False)
def get_output_handler(output_type: str, agent_execution_id: int, agent_config: dict, agent_tools: list = None,
                       memory=None):
    """Return the output handler matching ``output_type``.

    Args:
        output_type (str): "replace_tasks", "tasks" or "tools". Any other value falls back
            to the tool handler.
        agent_execution_id (int): The agent execution id.
        agent_config (dict): Agent configuration passed to the handler.
        agent_tools (list): Tools for ToolOutputHandler. Defaults to a fresh empty list per
            call instead of a list shared between calls.
        memory: Optional memory store for ToolOutputHandler.

    Returns:
        ReplaceTaskOutputHandler, TaskOutputHandler or ToolOutputHandler.
    """
    if output_type == "replace_tasks":
        return ReplaceTaskOutputHandler(agent_execution_id, agent_config)
    if output_type == "tasks":
        return TaskOutputHandler(agent_execution_id, agent_config)
    # "tools" and any unrecognised output type both use the tool handler.
    tools = agent_tools if agent_tools is not None else []
    return ToolOutputHandler(agent_execution_id, agent_config, tools, memory=memory)
================================================
FILE: superagi/agent/output_parser.py
================================================
import json
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import re
import ast
import json
from superagi.helper.json_cleaner import JsonCleaner
from superagi.lib.logger import logger
class AgentGPTAction(NamedTuple):
    """A tool invocation parsed from an LLM reply."""

    # Name of the tool to invoke.
    name: str
    # Tool arguments; the parsers in this module use {} when the reply has none.
    args: Dict
class AgentTasks(NamedTuple):
    """A list of task strings parsed from an LLM reply, plus an error message if parsing failed."""

    # NOTE(review): this default list object is shared by every instance created without
    # ``tasks``; never mutate it in place.
    tasks: List[str] = []
    # Empty string when there is no error.
    error: str = ""
class BaseOutputParser(ABC):
    """Abstract interface for parsers that turn raw LLM output into an AgentGPTAction."""

    @abstractmethod
    def parse(self, text: str) -> AgentGPTAction:
        """Parse ``text`` and return the AgentGPTAction it describes."""
class AgentSchemaOutputParser(BaseOutputParser):
    """Parses the output from the agent schema"""

    def parse(self, response: str) -> AgentGPTAction:
        """Parse a reply whose "tool" object holds the tool "name" and optional "args".

        A reply wrapped in a ``` fence is unwrapped first. The JSON section is then extracted
        and its booleans normalised before evaluation as a Python literal.

        Args:
            response (str): The raw LLM reply.

        Returns:
            AgentGPTAction: The tool name and its args ({} when absent).

        Raises:
            Exception: Whatever literal_eval or the key lookups raise (e.g. ValueError,
                SyntaxError, KeyError); it is logged and re-raised.
        """
        if response.startswith("```") and response.endswith("```"):
            response = "```".join(response.split("```")[1:-1])
        response = JsonCleaner.extract_json_section(response)
        # ast throws error if true/false params passed in json
        response = JsonCleaner.clean_boolean(response)

        # OpenAI returns `str(content_dict)`, literal_eval reverses this
        try:
            logger.debug("AgentSchemaOutputParser: ", response)
            response_obj = ast.literal_eval(response)
            tool_obj = response_obj['tool']
            args = tool_obj['args'] if 'args' in tool_obj else {}
            return AgentGPTAction(
                name=tool_obj['name'],
                args=args,
            )
        # Exception, not BaseException: KeyboardInterrupt/SystemExit are not parse errors.
        except Exception as e:
            logger.info(f"AgentSchemaOutputParser: Error parsing JSON response {e}")
            raise
class AgentSchemaToolOutputParser(BaseOutputParser):
    """Parses the output from the agent schema for the tool"""

    def parse(self, response: str) -> AgentGPTAction:
        """Parse a reply whose top-level object holds the tool "name" and optional "args".

        A reply wrapped in a ``` fence is unwrapped first. The JSON section is then extracted
        and its booleans normalised before evaluation as a Python literal.

        Args:
            response (str): The raw LLM reply.

        Returns:
            AgentGPTAction: The tool name and its args ({} when absent).

        Raises:
            Exception: Whatever literal_eval or the key lookups raise (e.g. ValueError,
                SyntaxError, KeyError); it is logged and re-raised.
        """
        if response.startswith("```") and response.endswith("```"):
            response = "```".join(response.split("```")[1:-1])
        response = JsonCleaner.extract_json_section(response)
        # ast throws error if true/false params passed in json
        response = JsonCleaner.clean_boolean(response)

        # OpenAI returns `str(content_dict)`, literal_eval reverses this
        try:
            # Prefix names this class; it previously reused "AgentSchemaOutputParser: ".
            logger.debug("AgentSchemaToolOutputParser: ", response)
            response_obj = ast.literal_eval(response)
            args = response_obj['args'] if 'args' in response_obj else {}
            return AgentGPTAction(
                name=response_obj['name'],
                args=args,
            )
        # Exception, not BaseException: KeyboardInterrupt/SystemExit are not parse errors.
        except Exception as e:
            logger.info(f"AgentSchemaToolOutputParser: Error parsing JSON response {e}")
            raise
================================================
FILE: superagi/agent/prompts/agent_queue_input.txt
================================================
Use the instruction below to break down the last response into an array of individual items that can be inserted into the queue.
INSTRUCTION: `{instruction}`
Respond with an array of items that are JSON parsable and can be inserted into the queue.
Ignore the header row in the case of csv.
================================================
FILE: superagi/agent/prompts/agent_recursive_summary.txt
================================================
AI, you are provided with a previous summary of interactions between the system, user, and assistant, as well as additional conversations that were not included in the original summary.
If the previous summary is empty, your task is to create a summary based solely on the new interactions.
Previous Summary: {previous_ltm_summary}
{past_messages}
If the previous summary is not empty, your final summary should integrate the new interactions into the existing summary to create a comprehensive recap of all interactions.
If the previous summary is empty, your summary should encapsulate the main points of the new conversations.
In both cases, highlight the key issues discussed, decisions made, and any actions assigned.
Please ensure that the final summary does not exceed {char_limit} characters.
================================================
FILE: superagi/agent/prompts/agent_summary.txt
================================================
AI, your task is to generate a concise summary of the previous interactions between the system, user, and assistant.
The interactions are as follows:
{past_messages}
This summary should encapsulate the main points of the conversation, highlighting the key issues discussed, decisions made, and any actions assigned.
It should serve as a recap of the past interaction, providing a clear understanding of the conversation's context and outcomes.
Please ensure that the summary does not exceed {char_limit} characters.
================================================
FILE: superagi/agent/prompts/agent_tool_input.txt
================================================
{tool_name} is the most suitable tool for the given instruction. Use {tool_name} to perform the instruction below, which helps you achieve the high-level goal.
High-Level GOAL:
`{goals}`
INSTRUCTION: `{instruction}`
Respond with tool name and tool arguments to achieve the instruction.
{tool_schema}
Respond with only valid JSON conforming to the following json schema. You should generate JSON as output and not JSON schema.
JSON Schema:
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "{tool_name}"
},
"args": {
"type": "object",
"description": "tool arguments"
}
},
"required": ["name", "args"]
}
================================================
FILE: superagi/agent/prompts/agent_tool_output.txt
================================================
Analyze {tool_name} output and follow the instruction to come up with the response:
High-Level GOAL:
`{goals}`
TOOL OUTPUT:
`{tool_output}`
INSTRUCTION: `{instruction}`
Analyze the instruction and respond with one of the below outputs. Response should be one of the below options:
{output_options}
================================================
FILE: superagi/agent/prompts/analyse_task.txt
================================================
High level goal:
{goals}
{task_instructions}
Your Current Task: `{current_task}`
Task History:
`{task_history}`
Based on this, your job is to understand the current task, pick out key parts, and think smart and fast.
Explain why you are doing each action, create a plan, and mention any worries you might have.
Ensure the tool for the next action is picked from the tool list below.
TOOLS:
{tools}
Respond with only valid JSON conforming to the following schema:
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"thoughts": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": "short reasoning"
}
},
"required": ["reasoning"]
},
"tool": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "tool name"
},
"args": {
"type": "object",
"description": "tool arguments"
}
},
"required": ["name", "args"]
}
}
}
================================================
FILE: superagi/agent/prompts/create_tasks.txt
================================================
You are an AI assistant that creates tasks.
High level goal:
{goals}
{task_instructions}
You have the following incomplete tasks `{pending_tasks}`. You have the following completed tasks `{completed_tasks}`.
Task History:
`{task_history}`
Based on this, create a single task in plain English to be completed by your AI system ONLY IF REQUIRED to get closer to or fully reach your high level goal.
Don't create any task if it is already covered in incomplete or completed tasks.
Ensure your new tasks do not deviate from completing the goal.
Your answer should be an array of tasks in plain English that can be used with JSON.parse() and NOTHING ELSE. Return empty array if no new task is required.
================================================
FILE: superagi/agent/prompts/initialize_tasks.txt
================================================
You are a task-generating AI known as SuperAGI. You are not a part of any system or device. Your role is to understand the goals presented to you, identify important components, go through the instructions provided by the user, and construct a thorough execution plan.
GOALS:
{goals}
{task_instructions}
Construct a sequence of actions, not exceeding 3 steps, to achieve this goal.
Submit your response as a formatted ARRAY of strings, suitable for utilization with JSON.parse().
================================================
FILE: superagi/agent/prompts/prioritize_tasks.txt
================================================
You are a task prioritization AI assistant.
High level goal:
{goals}
{task_instructions}
You have the following incomplete tasks `{pending_tasks}`. You have the following completed tasks `{completed_tasks}`.
Based on this, evaluate the incomplete tasks and sort them in the order of execution. In the output, the first task will be executed first, and so on.
Remove any incomplete tasks that are unnecessary or duplicated. Remove tasks that are already covered by completed tasks.
Remove tasks that do not help in achieving the main goal.
Your answer should be an array of strings that can be used with JSON.parse() and NOTHING ELSE.
================================================
FILE: superagi/agent/prompts/superagi.txt
================================================
You are SuperAGI an AI assistant to solve complex problems. Your decisions must always be made independently without seeking user assistance.
Play to your strengths as an LLM and pursue simple strategies with no legal complications.
If you have completed all your tasks or reached end state, make sure to use the "finish" tool.
GOALS:
{goals}
{instructions}
CONSTRAINTS:
{constraints}
TOOLS:
{tools}
PERFORMANCE EVALUATION:
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
2. Use instruction to decide the flow of execution and decide the next steps for achieving the task.
3. Constructively self-criticize your big-picture behavior constantly.
4. Reflect on past decisions and strategies to refine your approach.
5. Every tool has a cost, so be smart and efficient.
Respond with only valid JSON conforming to the following schema:
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"thoughts": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "thought"
},
"reasoning": {
"type": "string",
"description": "short reasoning"
},
"plan": {
"type": "string",
"description": "- short bulleted\n- list that conveys\n- long-term plan"
},
"criticism": {
"type": "string",
"description": "constructive self-criticism"
},
"speak": {
"type": "string",
"description": "thoughts summary to say to user"
}
},
"required": ["text", "reasoning", "plan", "criticism", "speak"],
"additionalProperties": false
},
"tool": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "tool name"
},
"args": {
"type": "object",
"description": "tool arguments"
}
},
"required": ["name", "args"],
"additionalProperties": false
}
},
"required": ["thoughts", "tool"],
"additionalProperties": false
}
================================================
FILE: superagi/agent/queue_step_handler.py
================================================
import time
import numpy as np
from superagi.agent.agent_message_builder import AgentLlmMessageBuilder
from superagi.agent.task_queue import TaskQueue
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.json_cleaner import JsonCleaner
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.agent_workflow_step_tool import AgentWorkflowStepTool
from superagi.models.agent import Agent
from superagi.types.queue_status import QueueStatus
class QueueStepHandler:
    """Handles the queue step of the agent workflow.

    On the first pass of a step, the LLM is asked to break the previous
    response into an array of items, which are pushed onto a Redis-backed
    TaskQueue. Each call then consumes one item and records it as an
    "Input: ..." feed, until the queue is empty.
    """

    def __init__(self, session, llm, agent_id: int, agent_execution_id: int):
        self.session = session
        self.llm = llm
        self.agent_execution_id = agent_execution_id
        self.agent_id = agent_id
        self.organisation = Agent.find_org_by_agent_id(self.session, agent_id=self.agent_id)

    def _queue_identifier(self, step_tool):
        """Queue name unique to this workflow step tool and agent execution."""
        return step_tool.unique_id + "_" + str(self.agent_execution_id)

    def _build_task_queue(self, step_tool):
        """Return the TaskQueue backing ``step_tool`` for this execution."""
        return TaskQueue(self._queue_identifier(step_tool))

    def execute_step(self):
        """Run one iteration of the queue step.

        Returns:
            str: "COMPLETE" when no items are left in the queue, otherwise
            "default" after one item has been consumed.
        """
        execution = AgentExecution.get_agent_execution_from_id(self.session, self.agent_execution_id)
        workflow_step = AgentWorkflowStep.find_by_id(self.session, execution.current_agent_step_id)
        step_tool = AgentWorkflowStepTool.find_by_id(self.session, workflow_step.action_reference_id)
        task_queue = self._build_task_queue(step_tool)

        # No status yet, or a previous pass finished: start a new pass.
        if not task_queue.get_status() or task_queue.get_status() == QueueStatus.COMPLETE.value:
            task_queue.set_status(QueueStatus.INITIATED.value)

        # Fill the queue once per pass, then move to PROCESSING.
        if task_queue.get_status() == QueueStatus.INITIATED.value:
            self._add_to_queue(task_queue, step_tool)
            execution.current_feed_group_id = "DEFAULT"
            task_queue.set_status(QueueStatus.PROCESSING.value)

        if not task_queue.get_tasks():
            task_queue.set_status(QueueStatus.COMPLETE.value)
            return "COMPLETE"
        self._consume_from_queue(task_queue)
        return "default"

    def _add_to_queue(self, task_queue: TaskQueue, step_tool: AgentWorkflowStepTool):
        """Ask the LLM for the item array and push the items onto ``task_queue``."""
        assistant_reply = self._process_input_instruction(step_tool)
        self._process_reply(task_queue, assistant_reply)

    def _consume_from_queue(self, task_queue: TaskQueue):
        """Record the head task as an input feed in a new feed group, then mark it processed."""
        tasks = task_queue.get_tasks()
        agent_execution = AgentExecution.find_by_id(self.session, self.agent_execution_id)
        if tasks:
            task = task_queue.get_first_task()
            # generating the new feed group id
            agent_execution.current_feed_group_id = "GROUP_" + str(int(time.time()))
            self.session.commit()

            task_response_feed = AgentExecutionFeed(agent_execution_id=self.agent_execution_id,
                                                    agent_id=self.agent_id,
                                                    feed="Input: " + task,
                                                    role="assistant",
                                                    feed_group_id=agent_execution.current_feed_group_id)
            self.session.add(task_response_feed)
            self.session.commit()
            task_queue.complete_task("PROCESSED")

    def _process_reply(self, task_queue: TaskQueue, assistant_reply: str):
        """Parse the LLM's array reply and push every item onto ``task_queue``.

        The reply is model-generated text, so it is parsed as data (JSON first,
        then a Python literal such as a single-quoted list) instead of being
        executed with ``eval``.
        """
        import ast
        import json

        array_text = JsonCleaner.extract_json_array_section(assistant_reply)
        logger.debug(f"Queue reply: {array_text}")
        try:
            items = json.loads(array_text)
        except (ValueError, TypeError):
            items = ast.literal_eval(array_text)
        task_array = np.array(items).flatten().tolist()
        # TaskQueue.add_task pushes onto the head of the Redis list and
        # get_first_task reads the head, so push in reverse order to consume
        # the items in the order the LLM listed them.
        for task in reversed(task_array):
            task_queue.add_task(str(task))
            logger.info(f"Added task to queue: {task}")

    def _process_input_instruction(self, step_tool):
        """Call the LLM with the queue-input prompt and return its reply text.

        Raises:
            RuntimeError: If the LLM response carries no content.
        """
        prompt = self._build_queue_input_prompt(step_tool)
        logger.info("Prompt: ", prompt)
        agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id)
        messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id,
                                          self.agent_execution_id) \
            .build_agent_messages(prompt, agent_feeds, history_enabled=step_tool.history_enabled,
                                  completion_prompt=step_tool.completion_prompt)

        current_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
        token_limit = TokenCounter(session=self.session, organisation_id=self.organisation.id) \
            .token_limit(self.llm.get_model())
        response = self.llm.chat_completion(messages, token_limit - current_tokens)

        # .get() so an error payload without a "message" key does not raise KeyError.
        if 'error' in response and response.get('message') is not None:
            ErrorHandler.handle_openai_errors(self.session, self.agent_id, self.agent_execution_id,
                                              response['message'])

        if 'content' not in response or response['content'] is None:
            raise RuntimeError("Failed to get response from llm")

        total_tokens = current_tokens + TokenCounter.count_message_tokens(response, self.llm.get_model())
        AgentExecution.update_tokens(self.session, self.agent_execution_id, total_tokens)
        return response['content']

    def _build_queue_input_prompt(self, step_tool: AgentWorkflowStepTool):
        """Load agent_queue_input.txt and substitute the step's input instruction."""
        queue_input_prompt = PromptReader.read_agent_prompt(__file__, "agent_queue_input.txt")
        queue_input_prompt = queue_input_prompt.replace("{instruction}", step_tool.input_instruction)
        return queue_input_prompt
================================================
FILE: superagi/agent/task_queue.py
================================================
import json
import redis
from superagi.config.config import get_config
# Redis "host:port" used by every TaskQueue; falls back to a local server
# when REDIS_URL is not configured.
redis_url = get_config('REDIS_URL') or "localhost:6379"
# TaskQueue manages current tasks and past tasks in Redis.
class TaskQueue:
    """TaskQueue manages current tasks and past tasks in Redis.

    Redis keys used:
        ``<queue_name>_q``: pending tasks. ``add_task`` pushes to the head,
            and ``get_first_task`` / ``complete_task`` read and pop the head,
            so the most recently added task is served first.
        ``<queue_name>_q_completed``: completed-task records, newest first.
            Each record is the ``str()`` of a ``{"task": ..., "response": ...}``
            dict.
        ``<queue_name>_q_status``: the queue status string.
    """

    def __init__(self, queue_name: str):
        self.queue_name = queue_name + "_q"
        self.completed_tasks = queue_name + "_q_completed"
        self.db = redis.Redis.from_url("redis://" + redis_url + "/0", decode_responses=True)

    @staticmethod
    def _parse_record(record: str):
        """Decode a completed-task record written by ``complete_task``.

        Records embed task text that can come from LLM output, so they are
        parsed with ``ast.literal_eval``, which accepts only literal syntax,
        rather than ``eval``, which would execute any expression.
        """
        import ast
        return ast.literal_eval(record)

    def add_task(self, task: str):
        """Push ``task`` onto the head of the pending list."""
        self.db.lpush(self.queue_name, task)

    def complete_task(self, response):
        """Pop the head task and record it with ``response``; no-op when empty."""
        # lpop returns None for an empty list, so there is no need to fetch the
        # whole list just to check whether it is empty.
        task = self.db.lpop(self.queue_name)
        if task is None:
            return
        self.db.lpush(self.completed_tasks, str({"task": task, "response": response}))

    def get_first_task(self):
        """Return the head of the pending list, or None when empty."""
        return self.db.lindex(self.queue_name, 0)

    def get_tasks(self):
        """Return all pending tasks, head first."""
        return self.db.lrange(self.queue_name, 0, -1)

    def get_completed_tasks(self):
        """Return every completed-task record as a dict, newest first."""
        tasks = self.db.lrange(self.completed_tasks, 0, -1)
        return [self._parse_record(task) for task in tasks]

    def clear_tasks(self):
        """Delete all pending tasks."""
        self.db.delete(self.queue_name)

    def get_last_task_details(self):
        """Return the most recently completed record as a dict, or None."""
        response = self.db.lindex(self.completed_tasks, 0)
        if response is None:
            return None
        return self._parse_record(response)

    def set_status(self, status):
        """Store the queue status."""
        self.db.set(self.queue_name + "_status", status)

    def get_status(self):
        """Return the stored queue status, or None if unset."""
        return self.db.get(self.queue_name + "_status")
================================================
FILE: superagi/agent/tool_builder.py
================================================
import importlib
import os
from superagi.config.config import get_config
from superagi.llms.llm_model_factory import get_model
from superagi.models.tool import Tool
from superagi.models.tool_config import ToolConfig
from superagi.models.agent import Agent
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseToolkitConfiguration
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
from superagi.helper.encyption_helper import decrypt_data, is_encrypted
class DBToolkitConfiguration(BaseToolkitConfiguration):
    """Toolkit configuration backed by ToolConfig rows in the database.

    Values stored for the toolkit take precedence; encrypted values are
    decrypted before being returned. When no non-empty row exists, lookup is
    delegated to BaseToolkitConfiguration.
    """

    session = None
    toolkit_id: int

    def __init__(self, session=None, toolkit_id=None):
        self.session = session
        self.toolkit_id = toolkit_id

    def get_tool_config(self, key: str):
        """Return the config value for ``key`` from the DB, else from the base class."""
        record = (self.session.query(ToolConfig)
                  .filter_by(key=key, toolkit_id=self.toolkit_id)
                  .first())
        if not (record and record.value):
            return super().get_tool_config(key=key)
        stored_value = record.value
        return decrypt_data(stored_value) if is_encrypted(stored_value) else stored_value
class ToolBuilder:
    """Builds agent-usable tool objects from Tool records and fills in their runtime defaults."""

    def __init__(self, session, agent_id: int, agent_execution_id: int = None):
        self.session = session
        self.agent_id = agent_id
        self.agent_execution_id = agent_execution_id

    def __validate_filename(self, filename):
        """
        Validate the filename by removing the last three characters if the filename ends with ".py".

        Args:
            filename (str): The filename.

        Returns:
            str: The validated filename.
        """
        if filename.endswith(".py"):
            return filename[:-3]  # Remove the last three characters (i.e., ".py")
        return filename

    def build_tool(self, tool: Tool):
        """
        Create an object of an agent usable tool dynamically.

        The tool's folder is searched for, in order, under superagi/tools,
        superagi/tools/external_tools and superagi/tools/marketplace_tools
        (relative to the current working directory). The first match is
        imported as a dotted module path.

        Args:
            tool (Tool): Tool object from which the agent tool is made.

        Returns:
            object: The instantiated tool, with ``toolkit_config`` set to a
            DBToolkitConfiguration for ``tool.toolkit_id``.
        """
        file_name = self.__validate_filename(filename=tool.file_name)

        tools_dir = ""
        tool_paths = ["superagi/tools", "superagi/tools/external_tools", "superagi/tools/marketplace_tools"]
        for tool_path in tool_paths:
            if os.path.exists(os.path.join(os.getcwd(), tool_path, tool.folder_name)):
                tools_dir = tool_path
                break
        # NOTE(review): when no directory matches, tools_dir stays "" and the
        # module name below starts with "." (a relative name), which
        # importlib.import_module rejects without a package; consider raising
        # a clearer error here.
        parsed_tools_dir = tools_dir.rstrip("/")
        module_name = ".".join(parsed_tools_dir.split("/") + [tool.folder_name, file_name])

        # Load the module dynamically and instantiate the tool class it defines.
        module = importlib.import_module(module_name)
        obj_class = getattr(module, tool.class_name)
        new_object = obj_class()
        new_object.toolkit_config = DBToolkitConfiguration(session=self.session, toolkit_id=tool.toolkit_id)
        return new_object

    def set_default_params_tool(self, tool, agent_config, agent_execution_config, model_api_key: str,
                                resource_summary: str = "", memory=None):
        """
        Set the default parameters for a tool.

        Only attributes the tool already has (checked with ``hasattr``) are set.

        Args:
            tool: Tool object to configure in place.
            agent_config (dict): Parsed agent configuration; ``agent_id`` and
                ``model`` are read.
            agent_execution_config (dict): Parsed execution configuration;
                ``goal`` and ``instruction`` are read when the tool needs them.
            model_api_key (str): The API key of the model.
            resource_summary (str): Substituted for ``{summary}`` in the
                QueryResourceTool description.
            memory: Passed to the tool's ToolResponseQueryManager.

        Returns:
            The same tool object, with defaults applied.
        """
        organisation = Agent.find_org_by_agent_id(self.session, agent_id=agent_config['agent_id'])
        if hasattr(tool, 'goals'):
            tool.goals = agent_execution_config["goal"]
        if hasattr(tool, 'instructions'):
            tool.instructions = agent_execution_config["instruction"]
        # The exclusion uses the same tool name as the description substitution
        # at the end of this method; "QueryResource" could not match a tool
        # named "QueryResourceTool".
        # NOTE(review): "gpt4" may have been meant as "gpt-4"; left unchanged so
        # the model chosen for existing agents is unaffected.
        if hasattr(tool, 'llm') and (agent_config["model"] == "gpt4" or agent_config[
                "model"] == "gpt-3.5-turbo") and tool.name != "QueryResourceTool":
            tool.llm = get_model(model="gpt-3.5-turbo", api_key=model_api_key, organisation_id=organisation.id,
                                 temperature=0.4)
        elif hasattr(tool, 'llm'):
            tool.llm = get_model(model=agent_config["model"], api_key=model_api_key, organisation_id=organisation.id,
                                 temperature=0.4)
        if hasattr(tool, 'agent_id'):
            tool.agent_id = self.agent_id
        if hasattr(tool, 'agent_execution_id'):
            tool.agent_execution_id = self.agent_execution_id
        if hasattr(tool, 'resource_manager'):
            tool.resource_manager = FileManager(session=self.session, agent_id=self.agent_id,
                                                agent_execution_id=self.agent_execution_id)
        if hasattr(tool, 'tool_response_manager'):
            tool.tool_response_manager = ToolResponseQueryManager(session=self.session,
                                                                  agent_execution_id=self.agent_execution_id,
                                                                  memory=memory)

        if tool.name == "QueryResourceTool":
            tool.description = tool.description.replace("{summary}", resource_summary)

        return tool
================================================
FILE: superagi/agent/tool_executor.py
================================================
from pydantic import ValidationError
from superagi.agent.common_types import ToolExecutorResponse
from superagi.apm.event_handler import EventHandler
from superagi.lib.logger import logger
class ToolExecutor:
    """Executes the tool with the given args."""
    FINISH = "finish"

    def __init__(self, organisation_id: int, agent_id: int, tools: list, agent_execution_id: int):
        self.organisation_id = organisation_id
        self.agent_id = agent_id
        self.tools = tools
        self.agent_execution_id = agent_execution_id

    def execute(self, session, tool_name, tool_args):
        """Executes the tool with the given args.

        Tool names are matched case-insensitively with spaces removed.

        Args:
            session (Session): The database session.
            tool_name (str): The name of the tool to execute.
            tool_args (dict): The arguments to pass to the tool.

        Returns:
            ToolExecutorResponse: status "COMPLETE" for finish/empty names,
            "SUCCESS" or "ERROR" (retry=True) for known tools, "ERROR" with
            retry=False for the special "ERROR" tool name, and "ERROR" with
            retry=True for unknown tools.
        """
        tools = {t.name.lower().replace(" ", ""): t for t in self.tools}
        tool_name = tool_name.lower().replace(" ", "")

        if tool_name == ToolExecutor.FINISH or tool_name == "":
            logger.info("\nTask Finished :) \n")
            return ToolExecutorResponse(status="COMPLETE", result="")

        if tool_name in tools.keys():
            status = "SUCCESS"
            tool = tools[tool_name]
            retry = False
            EventHandler(session=session).create_event('tool_used', {'tool_name': tool.name,
                                                                     'agent_execution_id': self.agent_execution_id},
                                                       self.agent_id, self.organisation_id)
            try:
                parsed_args = self.clean_tool_args(tool_args)
                observation = tool.execute(parsed_args)
            except ValidationError as e:
                status = "ERROR"
                retry = True
                observation = (
                    f"Validation Error in args: {str(e)}, args: {tool_args}"
                )
            except Exception as e:
                status = "ERROR"
                retry = True
                observation = (
                    f"Error1: {str(e)}, {type(e).__name__}, args: {tool_args}"
                )
            output = ToolExecutorResponse(status=status, result=f"Tool {tool.name} returned: {observation}",
                                          retry=retry)
        elif tool_name == "error":
            # tool_name was lower-cased above, so compare against "error";
            # comparing against "ERROR" could never match.
            output = ToolExecutorResponse(status="ERROR", result=f"Error Tool Name: {tool_args}. ", retry=False)
        else:
            result = (
                f"Unknown tool '{tool_name}'. "
                f"Please refer to the 'TOOLS' list for available "
                f"tools and only respond in the specified JSON format."
            )
            output = ToolExecutorResponse(status="ERROR", result=result, retry=True)

        logger.info("Tool Response : " + str(output) + "\n")
        return output

    def clean_tool_args(self, args):
        """Unwrap ``{"value": x}`` argument dicts to ``x``; other values pass through."""
        parsed_args = {}
        for key, value in args.items():
            parsed_args[key] = value["value"] if isinstance(value, dict) and "value" in value else value
        return parsed_args
================================================
FILE: superagi/agent/types/__init__.py
================================================
================================================
FILE: superagi/agent/types/agent_execution_status.py
================================================
from enum import Enum
class AgentExecutionStatus(Enum):
RUNNING = 'RUNNING'
WAITING_FOR_PERMISSION = 'WAITING_FOR_PERMISSION'
ITERATION_LIMIT_EXCEEDED = 'ITERATION_LIMIT_EXCEEDED'
WAIT_STEP = 'WAIT_STEP'
COMPLETED = 'COMPLETED'
@classmethod
def get_agent_execution_status(cls, store):
if store is None:
raise ValueError("Storage type cannot be None.")
store = store.upper()
if store in cls.__members__:
return cls[store]
raise ValueError(f"{store} is not a valid storage name.")
================================================
FILE: superagi/agent/types/agent_workflow_step_action_types.py
================================================
from enum import Enum
class AgentWorkflowStepAction(Enum):
ITERATION_WORKFLOW = 'ITERATION_WORKFLOW'
TOOL = 'TOOL'
WAIT_STEP = 'WAIT_STEP'
@classmethod
def get_agent_workflow_action_type(cls, store):
if store is None:
raise ValueError("Storage type cannot be None.")
store = store.upper()
if store in cls.__members__:
return cls[store]
raise ValueError(f"{store} is not a valid storage name.")
================================================
FILE: superagi/agent/types/wait_step_status.py
================================================
from enum import Enum
class AgentWorkflowStepWaitStatus(Enum):
PENDING = 'PENDING'
WAITING = 'WAITING'
COMPLETED = 'COMPLETED'
@classmethod
def get_agent_workflow_step_wait_status(cls, store):
if store is None:
raise ValueError("Storage type cannot be None.")
store = store.upper()
if store in cls.__members__:
return cls[store]
raise ValueError(f"{store} is not a valid storage name.")
================================================
FILE: superagi/agent/workflow_seed.py
================================================
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.agent.agent_prompt_template import AgentPromptTemplate
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
from superagi.tools.apollo.apollo_search import ApolloSearchTool
from superagi.tools.code.write_code import CodingTool
from superagi.tools.code.write_spec import WriteSpecTool
from superagi.tools.code.write_test import WriteTestTool
from superagi.tools.email.read_email import ReadEmailTool
from superagi.tools.email.send_email import SendEmailTool
from superagi.tools.file.append_file import AppendFileTool
from superagi.tools.file.list_files import ListFileTool
from superagi.tools.file.read_file import ReadFileTool
from superagi.tools.file.write_file import WriteFileTool
from superagi.tools.github.add_file import GithubAddFileTool
from superagi.tools.google_calendar.create_calendar_event import CreateEventCalendarTool
from superagi.tools.google_calendar.google_calendar_toolkit import GoogleCalendarToolKit
from superagi.tools.google_search.google_search import GoogleSearchTool
from superagi.tools.jira.create_issue import CreateIssueTool
from superagi.tools.searx.searx import SearxSearchTool
from superagi.tools.slack.send_message import SlackMessageTool
from superagi.tools.thinking.tools import ThinkingTool
from superagi.tools.twitter.send_tweets import SendTweetsTool
from superagi.tools.webscaper.tools import WebScraperTool
class AgentWorkflowSeed:
# Find or create the "Sales Engagement Workflow" and wire its steps:
# list files -> read leads -> task queue -> (per item) search company ->
# wait for permission -> search -> send email -> wait 2 minutes ->
# read email -> send email -> back to the task queue. Commits at the end.
@classmethod
def build_sales_workflow(cls, session):
agent_workflow = AgentWorkflow.find_or_create_by_name(session, "Sales Engagement Workflow",
"Sales Engagement Workflow")
# step1 (Apollo lead search) is disabled, so step2 is the TRIGGER step.
# step1 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
# str(agent_workflow.id) + "_step1",
# ApolloSearchTool().name,
# "Search for leads based on the given goals",
# step_type="TRIGGER")
#
step2 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step2",
ListFileTool().name,
"list the files",
step_type="TRIGGER")
step3 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step3",
ReadFileTool().name,
"Read the leads from the file")
# Task-queue step: it ends (the "COMPLETE" edge below) once all items have been consumed.
step4 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step4",
"TASK_QUEUE",
"Break the above response array of items",
completion_prompt="Get array of items from the above response. Array should suitable utilization of JSON.parse().")
step5 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step5",
GoogleSearchTool().name,
"Search about the company in which the lead is working")
# Permission step: "YES" continues to step7, "NO" goes back to step5 (see edges below).
step6 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step6",
"WAIT_FOR_PERMISSION",
"Email will be based on this content. Do you want send the email?")
step7 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step7",
SearxSearchTool().name,
"Search about the company given in the high-end goal only")
step8 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step8",
SendEmailTool().name,
"Customize the Email according to the company information in the mail")
# Wait step; the last argument is the wait duration (2*60, matching "Wait for 2 minutes").
step9 = AgentWorkflowStep.find_or_create_wait_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step9",
"Wait for 2 minutes",
2*60)
# NOTE(review): a personal email address is hard-coded in seed data; consider making it configurable.
step10 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step10",
ReadEmailTool().name,
"Read the email from adarshdeepmurari@gmail.com")
step11 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step11",
SendEmailTool().name,
"Customize the Email according to the company information in the mail")
# Edges: the optional last argument is the condition ("COMPLETE", "YES", "NO");
# -1 presumably marks the end of the workflow (confirm in add_next_workflow_step).
# AgentWorkflowStep.add_next_workflow_step(session, step1.id, step2.id)
AgentWorkflowStep.add_next_workflow_step(session, step2.id, step3.id)
AgentWorkflowStep.add_next_workflow_step(session, step3.id, step4.id)
AgentWorkflowStep.add_next_workflow_step(session, step4.id, -1, "COMPLETE")
AgentWorkflowStep.add_next_workflow_step(session, step4.id, step5.id)
AgentWorkflowStep.add_next_workflow_step(session, step5.id, step6.id)
AgentWorkflowStep.add_next_workflow_step(session, step6.id, step7.id, "YES")
AgentWorkflowStep.add_next_workflow_step(session, step6.id, step5.id, "NO")
AgentWorkflowStep.add_next_workflow_step(session, step7.id, step8.id)
AgentWorkflowStep.add_next_workflow_step(session, step8.id, step9.id)
AgentWorkflowStep.add_next_workflow_step(session, step9.id, step10.id)
AgentWorkflowStep.add_next_workflow_step(session, step10.id, step11.id)
# Loop back to the task queue to process the next item.
AgentWorkflowStep.add_next_workflow_step(session, step11.id, step4.id)
session.commit()
# Find or create the "Recruitment Workflow": list resume files, queue them,
# then for each one read it and send an acceptance ("YES") or rejection
# ("NO") email before returning to the queue. Commits at the end.
@classmethod
def build_recruitment_workflow(cls, session):
agent_workflow = AgentWorkflow.find_or_create_by_name(session, "Recruitment Workflow",
"Recruitment Workflow")
step1 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step1",
ListFileTool().name,
"List the files from the resource manager",
step_type="TRIGGER")
# Task-queue step: it ends (the "COMPLETE" edge below) once all items have been consumed.
step2 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step2",
"TASK_QUEUE",
"Break the above response array of items",
completion_prompt="Get array of items from the above response. Array should suitable utilization of JSON.parse(). Skip job_description file from list.")
# The extra positional argument is presumably the step's output instruction,
# whose YES/NO answer selects step4 or step5 (confirm against
# find_or_create_tool_workflow_step).
step3 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step3",
ReadFileTool().name,
"Read the resume from above input",
"Check if the resume matches High-Level GOAL")
step4 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step4",
SendEmailTool().name,
"Write a custom acceptance Email to the candidates")
step5 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step5",
SendEmailTool().name,
"Write a custom Reject Email to the candidates")
# -1 presumably marks the end of the workflow (confirm in add_next_workflow_step).
AgentWorkflowStep.add_next_workflow_step(session, step1.id, step2.id)
AgentWorkflowStep.add_next_workflow_step(session, step2.id, step3.id)
AgentWorkflowStep.add_next_workflow_step(session, step2.id, -1, "COMPLETE")
AgentWorkflowStep.add_next_workflow_step(session, step3.id, step4.id, "YES")
AgentWorkflowStep.add_next_workflow_step(session, step3.id, step5.id, "NO")
# Both email steps loop back to the task queue for the next resume.
AgentWorkflowStep.add_next_workflow_step(session, step4.id, step2.id)
AgentWorkflowStep.add_next_workflow_step(session, step5.id, step2.id)
session.commit()
# Find or create the "SuperCoder" workflow: write spec -> write tests ->
# write code -> ask permission. "YES" ends the workflow (-1); "NO" goes back
# to the coding step.
# NOTE(review): unlike the sales/recruitment builders there is no
# session.commit() here; confirm whether the helpers commit on their own.
@classmethod
def build_coding_workflow(cls, session):
agent_workflow = AgentWorkflow.find_or_create_by_name(session, "SuperCoder", "SuperCoder")
step1 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step1",
WriteSpecTool().name,
"Spec description",
step_type="TRIGGER")
step2 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step2",
WriteTestTool().name,
"Test description")
step3 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step3",
CodingTool().name,
"Code description")
step4 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id,
str(agent_workflow.id) + "_step4",
"WAIT_FOR_PERMISSION",
"Your code is ready. Do you want end?")
AgentWorkflowStep.add_next_workflow_step(session, step1.id, step2.id)
AgentWorkflowStep.add_next_workflow_step(session, step2.id, step3.id)
AgentWorkflowStep.add_next_workflow_step(session, step3.id, step4.id)
AgentWorkflowStep.add_next_workflow_step(session, step4.id, -1, "YES")
AgentWorkflowStep.add_next_workflow_step(session, step4.id, step3.id, "NO")
@classmethod
def build_goal_based_agent(cls, session):
    """Seed the "Goal Based Workflow": one iteration step that repeats until "COMPLETE"."""
    workflow = AgentWorkflow.find_or_create_by_name(session, "Goal Based Workflow", "Goal Based Workflow")
    iteration_step = AgentWorkflowStep.find_or_create_iteration_workflow_step(
        session, workflow.id, str(workflow.id) + "_step1",
        "Goal Based Agent-I", step_type="TRIGGER")
    # Default transition loops onto itself; "COMPLETE" leaves via -1.
    for target_id, *response in ((iteration_step.id,), (-1, "COMPLETE")):
        AgentWorkflowStep.add_next_workflow_step(session, iteration_step.id, target_id, *response)
@classmethod
def build_task_based_agent(cls, session):
    """Seed the "Dynamic Task Workflow": initialise tasks, then loop the dynamic task queue."""
    workflow = AgentWorkflow.find_or_create_by_name(session, "Dynamic Task Workflow", "Dynamic Task Workflow")
    prefix = str(workflow.id)
    find_iteration_step = AgentWorkflowStep.find_or_create_iteration_workflow_step

    init_tasks = find_iteration_step(session, workflow.id, prefix + "_step1",
                                     "Initialize Tasks-I", step_type="TRIGGER")
    task_queue = find_iteration_step(session, workflow.id, prefix + "_step2",
                                     "Dynamic Task Queue-I", step_type="NORMAL")

    link = AgentWorkflowStep.add_next_workflow_step
    link(session, init_tasks.id, task_queue.id)
    # The queue step repeats itself until it reports "COMPLETE" (target -1).
    link(session, task_queue.id, task_queue.id)
    link(session, task_queue.id, -1, "COMPLETE")
@classmethod
def build_fixed_task_based_agent(cls, session):
    """Seed the "Fixed Task Workflow": initialise tasks, then loop the fixed task queue."""
    workflow = AgentWorkflow.find_or_create_by_name(session, "Fixed Task Workflow", "Fixed Task Workflow")
    prefix = str(workflow.id)
    find_iteration_step = AgentWorkflowStep.find_or_create_iteration_workflow_step

    init_tasks = find_iteration_step(session, workflow.id, prefix + "_step1",
                                     "Initialize Tasks-I", step_type="TRIGGER")
    task_queue = find_iteration_step(session, workflow.id, prefix + "_step2",
                                     "Fixed Task Queue-I", step_type="NORMAL")

    link = AgentWorkflowStep.add_next_workflow_step
    link(session, init_tasks.id, task_queue.id)
    # The queue step repeats itself until it reports "COMPLETE" (target -1).
    link(session, task_queue.id, task_queue.id)
    link(session, task_queue.id, -1, "COMPLETE")
class IterationWorkflowSeed:
    """Seeds the iteration workflows (prompt step sequences) used by the agent workflows."""

    @classmethod
    def build_single_step_agent(cls, session):
        """Seed "Goal Based Agent-I": a single history-enabled tool-selection step."""
        workflow = IterationWorkflow.find_or_create_by_name(session, "Goal Based Agent-I", "Goal Based Agent")
        template = AgentPromptTemplate.get_super_agi_single_prompt()
        IterationWorkflowStep.find_or_create_step(
            session, workflow.id, "gb1",
            template["prompt"], str(template["variables"]),
            "TRIGGER", "tools",
            history_enabled=True,
            completion_prompt="Determine which next tool to use, and respond using the format specified above:")

    @classmethod
    def build_task_based_agents(cls, session):
        """Seed "Dynamic Task Queue-I": analyse task -> create tasks -> prioritise tasks."""
        workflow = IterationWorkflow.find_or_create_by_name(session, "Dynamic Task Queue-I",
                                                            "Dynamic Task Queue", has_task_queue=True)
        # (prompt template factory, unique id, step type, output type), in execution order.
        specs = (
            (AgentPromptTemplate.analyse_task, "tb1", "TRIGGER", "tools"),
            (AgentPromptTemplate.create_tasks, "tb2", "NORMAL", "tasks"),
            (AgentPromptTemplate.prioritize_tasks, "tb3", "NORMAL", "replace_tasks"),
        )
        steps = []
        for build_template, unique_id, step_type, output_type in specs:
            template = build_template()
            steps.append(IterationWorkflowStep.find_or_create_step(
                session, workflow.id, unique_id,
                template["prompt"], str(template["variables"]), step_type, output_type))
        # Chain tb1 -> tb2 -> tb3; tb3's next_step_id is not touched here.
        for current, following in zip(steps, steps[1:]):
            current.next_step_id = following.id
        session.commit()

    @classmethod
    def build_initialize_task_workflow(cls, session):
        """Seed "Initialize Tasks-I": a single step whose output is a task list."""
        workflow = IterationWorkflow.find_or_create_by_name(session, "Initialize Tasks-I", "Initialize Tasks",
                                                            has_task_queue=True)
        template = AgentPromptTemplate.start_task_based()
        IterationWorkflowStep.find_or_create_step(
            session, workflow.id, "init_task1",
            template["prompt"], str(template["variables"]), "TRIGGER", "tasks")

    @classmethod
    def build_action_based_agents(cls, session):
        """Seed "Fixed Task Queue-I": a single analyse-task step producing a tool call."""
        workflow = IterationWorkflow.find_or_create_by_name(session, "Fixed Task Queue-I", "Fixed Task Queue",
                                                            has_task_queue=True)
        template = AgentPromptTemplate.analyse_task()
        IterationWorkflowStep.find_or_create_step(
            session, workflow.id, "ab1",
            template["prompt"], str(template["variables"]), "TRIGGER", "tools")
================================================
FILE: superagi/apm/__init__.py
================================================
================================================
FILE: superagi/apm/analytics_helper.py
================================================
from typing import List, Dict, Union, Any
from sqlalchemy import text, func, and_
from sqlalchemy.orm import Session
from superagi.models.events import Event
class AnalyticsHelper:
    """APM dashboard aggregates computed from one organisation's ``Event`` rows."""

    # Event names recorded when an agent run ends (normally or at the iteration limit).
    _RUN_END_EVENTS = ('run_completed', 'run_iteration_limit_crossed')

    def __init__(self, session: Session, organisation_id: int):
        self.session = session
        self.organisation_id = organisation_id

    def calculate_run_completed_metrics(self) -> Dict[str, Dict[str, Union[int, List[Dict[str, int]]]]]:
        """Count agents, finished runs and consumed tokens, grouped by the agent's model.

        An agent's model is read from its ``agent_created`` event.

        Returns:
            A dict with ``agent_details``, ``run_details`` and ``tokens_details``. Each holds a
            ``total_*`` integer and a ``model_metrics`` list of ``{'name': model, 'value': n}``.
        """
        agent_model_query = self.session.query(
            Event.event_property['model'].label('model'),
            Event.agent_id
        ).filter_by(event_name="agent_created", org_id=self.organisation_id).subquery()

        agent_runs_query = self.session.query(
            agent_model_query.c.model,
            func.count(Event.id).label('runs')
        ).join(Event, and_(Event.agent_id == agent_model_query.c.agent_id,
                           Event.org_id == self.organisation_id)) \
            .filter(Event.event_name.in_(self._RUN_END_EVENTS)) \
            .group_by(agent_model_query.c.model).subquery()

        agent_tokens_query = self.session.query(
            agent_model_query.c.model,
            func.sum(text("(event_property->>'tokens_consumed')::int")).label('tokens')
        ).join(Event, and_(Event.agent_id == agent_model_query.c.agent_id,
                           Event.org_id == self.organisation_id)) \
            .filter(Event.event_name.in_(self._RUN_END_EVENTS)) \
            .group_by(agent_model_query.c.model).subquery()

        agent_count_query = self.session.query(
            agent_model_query.c.model,
            func.count(agent_model_query.c.agent_id).label('agents')
        ).group_by(agent_model_query.c.model).subquery()

        agents = self.session.query(agent_count_query).all()
        runs = self.session.query(agent_runs_query).all()
        tokens = self.session.query(agent_tokens_query).all()

        # SQL SUM() is NULL when none of a model's run events carries tokens_consumed;
        # report that as 0 instead of letting Python's sum() fail on None.
        token_values = [(item.model, item.tokens or 0) for item in tokens]

        return {
            'agent_details': {
                'total_agents': sum(item.agents for item in agents),
                'model_metrics': [{'name': item.model, 'value': item.agents} for item in agents]
            },
            'run_details': {
                'total_runs': sum(item.runs for item in runs),
                'model_metrics': [{'name': item.model, 'value': item.runs} for item in runs]
            },
            'tokens_details': {
                'total_tokens': sum(value for _, value in token_values),
                'model_metrics': [{'name': model, 'value': value} for model, value in token_values]
            },
        }

    def fetch_agent_data(self) -> Dict[str, List[Dict[str, Any]]]:
        """Per-agent totals: tokens, calls, finished runs, distinct tools and mean run time.

        Average run time is the mean of (last end event - first ``run_created``) per
        execution, in seconds (``extract('epoch', ...)``). Missing numeric totals are 0.
        """
        agent_subquery = self.session.query(
            Event.agent_id,
            Event.event_property['agent_name'].label('agent_name'),
            Event.event_property['model'].label('model')
        ).filter_by(event_name="agent_created", org_id=self.organisation_id).subquery()

        run_subquery = self.session.query(
            Event.agent_id,
            func.sum(text("(event_property->>'tokens_consumed')::int")).label('total_tokens'),
            func.sum(text("(event_property->>'calls')::int")).label('total_calls'),
            func.count(Event.id).label('runs_completed'),
        ).filter(and_(Event.event_name.in_(self._RUN_END_EVENTS),
                      Event.org_id == self.organisation_id)) \
            .group_by(Event.agent_id).subquery()

        tool_subquery = self.session.query(
            Event.agent_id,
            func.array_agg(Event.event_property['tool_name'].distinct()).label('tools_used'),
        ).filter_by(event_name="tool_used", org_id=self.organisation_id).group_by(Event.agent_id).subquery()

        start_time_subquery = self.session.query(
            Event.agent_id,
            Event.event_property['agent_execution_id'].label('agent_execution_id'),
            func.min(func.extract('epoch', Event.created_at)).label('start_time')
        ).filter_by(event_name="run_created", org_id=self.organisation_id) \
            .group_by(Event.agent_id, Event.event_property['agent_execution_id']).subquery()

        end_time_subquery = self.session.query(
            Event.agent_id,
            Event.event_property['agent_execution_id'].label('agent_execution_id'),
            func.max(func.extract('epoch', Event.created_at)).label('end_time')
        ).filter(and_(Event.event_name.in_(self._RUN_END_EVENTS),
                      Event.org_id == self.organisation_id)) \
            .group_by(Event.agent_id, Event.event_property['agent_execution_id']).subquery()

        time_diff_subquery = self.session.query(
            start_time_subquery.c.agent_id,
            (func.avg(end_time_subquery.c.end_time - start_time_subquery.c.start_time)).label('avg_run_time')
        ).join(end_time_subquery,
               start_time_subquery.c.agent_execution_id == end_time_subquery.c.agent_execution_id) \
            .group_by(start_time_subquery.c.agent_id).subquery()

        query = self.session.query(
            agent_subquery.c.agent_id,
            agent_subquery.c.agent_name,
            agent_subquery.c.model,
            run_subquery.c.total_tokens,
            run_subquery.c.total_calls,
            run_subquery.c.runs_completed,
            tool_subquery.c.tools_used,
            time_diff_subquery.c.avg_run_time
        ).outerjoin(run_subquery, run_subquery.c.agent_id == agent_subquery.c.agent_id) \
            .outerjoin(tool_subquery, tool_subquery.c.agent_id == agent_subquery.c.agent_id) \
            .outerjoin(time_diff_subquery, time_diff_subquery.c.agent_id == agent_subquery.c.agent_id)

        agent_details = [{
            "name": row.agent_name,
            "agent_id": row.agent_id,
            "runs_completed": row.runs_completed or 0,
            "total_calls": row.total_calls or 0,
            "total_tokens": row.total_tokens or 0,
            "tools_used": row.tools_used,
            "model_name": row.model,
            "avg_run_time": row.avg_run_time or 0,
        } for row in query.all()]
        return {'agent_details': agent_details}

    def fetch_agent_runs(self, agent_id: int) -> List[Dict[str, Any]]:
        """List finished runs of ``agent_id`` with name, tokens, calls and start/end timestamps."""
        completed_subquery = self.session.query(
            Event.event_property['agent_execution_id'].label('completed_agent_execution_id'),
            Event.event_property['tokens_consumed'].label('tokens_consumed'),
            Event.event_property['calls'].label('calls'),
            Event.updated_at
        ).filter(Event.event_name.in_(self._RUN_END_EVENTS), Event.agent_id == agent_id,
                 Event.org_id == self.organisation_id).subquery()

        created_subquery = self.session.query(
            Event.event_property['agent_execution_id'].label('created_agent_execution_id'),
            Event.event_property['agent_execution_name'].label('agent_execution_name'),
            Event.created_at
        ).filter(Event.event_name == "run_created", Event.agent_id == agent_id,
                 Event.org_id == self.organisation_id).subquery()

        query = self.session.query(
            created_subquery.c.agent_execution_name,
            completed_subquery.c.tokens_consumed,
            completed_subquery.c.calls,
            created_subquery.c.created_at,
            completed_subquery.c.updated_at
        ).join(completed_subquery,
               completed_subquery.c.completed_agent_execution_id == created_subquery.c.created_agent_execution_id)

        return [{
            'name': row.agent_execution_name,
            'tokens_consumed': int(row.tokens_consumed) if row.tokens_consumed else 0,
            'calls': int(row.calls) if row.calls else 0,
            'created_at': row.created_at,
            'updated_at': row.updated_at
        } for row in query.all()]

    def get_active_runs(self) -> List[Dict[str, Any]]:
        """List executions that have a ``run_created`` event but no run-end event yet."""
        end_event_subquery = self.session.query(
            Event.event_property['agent_execution_id'].label('agent_execution_id'),
        ).filter(
            Event.event_name.in_(self._RUN_END_EVENTS),
            Event.org_id == self.organisation_id
        ).subquery()

        start_subquery = self.session.query(
            Event.event_property['agent_execution_id'].label('agent_execution_id'),
            Event.event_property['agent_execution_name'].label('agent_execution_name'),
            Event.created_at,
            Event.agent_id
        ).filter_by(event_name="run_created", org_id=self.organisation_id).subquery()

        agent_created_subquery = self.session.query(
            Event.event_property['agent_name'].label('agent_name'),
            Event.agent_id
        ).filter_by(event_name="agent_created", org_id=self.organisation_id).subquery()

        query = self.session.query(
            start_subquery.c.agent_execution_name,
            start_subquery.c.created_at,
            agent_created_subquery.c.agent_name
        ).select_from(start_subquery)
        # Anti-join: keep started executions with no matching end event.
        query = query.outerjoin(
            end_event_subquery,
            start_subquery.c.agent_execution_id == end_event_subquery.c.agent_execution_id
        ).filter(end_event_subquery.c.agent_execution_id.is_(None))
        query = query.join(agent_created_subquery,
                           start_subquery.c.agent_id == agent_created_subquery.c.agent_id)

        return [{
            'name': row.agent_execution_name,
            'created_at': row.created_at,
            'agent_name': row.agent_name or 'Unknown',
        } for row in query.all()]
================================================
FILE: superagi/apm/call_log_helper.py
================================================
import logging
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy import func, distinct
from superagi.models.call_logs import CallLogs
from superagi.models.agent import Agent
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit
class CallLogHelper:
    """Creates and reads ``CallLogs`` rows scoped to one organisation."""

    def __init__(self, session: Session, organisation_id: int):
        self.session = session
        self.organisation_id = organisation_id

    def create_call_log(self, agent_execution_name: str, agent_id: int, tokens_consumed: int, tool_used: str, model: str) -> Optional[CallLogs]:
        """Insert and commit a call log row for this organisation.

        Returns:
            The committed ``CallLogs`` row, or None if the database raised (error is logged).
        """
        try:
            call_log = CallLogs(
                agent_execution_name=agent_execution_name,
                agent_id=agent_id,
                tokens_consumed=tokens_consumed,
                tool_used=tool_used,
                model=model,
                org_id=self.organisation_id,
            )
            self.session.add(call_log)
            self.session.commit()
            return call_log
        except SQLAlchemyError as err:
            # Without a rollback the session stays in a failed state and every later
            # operation on it raises; restore it before reporting the error.
            self.session.rollback()
            logging.error(f"Error while creating call log: {str(err)}")
            return None

    def fetch_data(self, model: str):
        """Summarise call logs for ``model`` and list them newest-first.

        Returns:
            ``{'model', 'total_tokens', 'total_calls', 'total_agents', 'runs': [...]}``, or None
            on a database error (logged). Each run carries its agent name and the name of the
            organisation toolkit that owns ``tool_used`` (None when not resolvable).
        """
        try:
            result = self.session.query(
                func.sum(CallLogs.tokens_consumed),
                func.count(CallLogs.id),
                func.count(distinct(CallLogs.agent_id))
            ).filter(CallLogs.model == model, CallLogs.org_id == self.organisation_id).first()

            if result is None:
                return None

            model_data = {
                'model': model,
                'total_tokens': result[0],
                'total_calls': result[1],
                'total_agents': result[2],
                'runs': []
            }

            runs = self.session.query(CallLogs).filter(CallLogs.model == model,
                                                       CallLogs.org_id == self.organisation_id).all()

            run_agent_ids = [run.agent_id for run in runs]
            agents = self.session.query(Agent).filter(Agent.id.in_(run_agent_ids)).all()
            agent_id_name_map = {agent.id: agent.name for agent in agents}

            # Only this organisation's toolkits are used to resolve tool -> toolkit.
            toolkits = self.session.query(Toolkit.id, Toolkit.name) \
                .filter(Toolkit.organisation_id == self.organisation_id).all()
            toolkit_id_name_map = {toolkit_id: toolkit_name for toolkit_id, toolkit_name in toolkits}

            tools_used = [run.tool_used for run in runs]
            tools = self.session.query(Tool).filter(Tool.name.in_(tools_used),
                                                    Tool.toolkit_id.in_(list(toolkit_id_name_map))) \
                .all()
            tools_name_toolkit_id_map = {tool.name: tool.toolkit_id for tool in tools}

            for run in runs:
                toolkit_id = tools_name_toolkit_id_map.get(run.tool_used)
                model_data['runs'].append({
                    'id': run.id,
                    'agent_execution_name': run.agent_execution_name,
                    'agent_id': run.agent_id,
                    'agent_name': agent_id_name_map.get(run.agent_id),
                    'tokens_consumed': run.tokens_consumed,
                    'tool_used': run.tool_used,
                    # Previously this held the toolkit *id*; report the name the key promises.
                    'toolkit_name': toolkit_id_name_map.get(toolkit_id),
                    'org_id': run.org_id,
                    'created_at': run.created_at,
                    'updated_at': run.updated_at,
                })

            model_data['runs'] = model_data['runs'][::-1]

            return model_data
        except SQLAlchemyError as err:
            # A failed query can abort the transaction; reset the shared session.
            self.session.rollback()
            logging.error(f"Error while fetching call log data: {str(err)}")
            return None
================================================
FILE: superagi/apm/event_handler.py
================================================
import logging
from typing import Optional, Dict
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from superagi.models.events import Event
class EventHandler:
    """Persists APM ``Event`` rows through a caller-supplied SQLAlchemy session."""

    def __init__(self, session: Session):
        self.session = session

    def create_event(self, event_name: str, event_property: Dict, agent_id: int,
                     org_id: int, event_value: int = 1) -> Optional[Event]:
        """Insert and commit one event.

        Args:
            event_name: Event type, e.g. 'run_completed'.
            event_property: Arbitrary properties stored with the event.
            agent_id: Agent the event belongs to.
            org_id: Organisation the event belongs to.
            event_value: Numeric value of the event (defaults to 1).

        Returns:
            The committed ``Event``, or None if the database raised (the error is logged).
        """
        try:
            event = Event(
                event_name=event_name,
                event_value=event_value,
                event_property=event_property,
                agent_id=agent_id,
                org_id=org_id,
            )
            self.session.add(event)
            self.session.commit()
            return event
        except SQLAlchemyError as err:
            # Roll back so the shared session is usable again after the failed commit;
            # otherwise every subsequent operation on it raises.
            self.session.rollback()
            logging.error(f"Error while creating event: {str(err)}")
            return None
================================================
FILE: superagi/apm/knowledge_handler.py
================================================
from sqlalchemy.orm import Session
from superagi.models.events import Event
from superagi.models.knowledges import Knowledges
from sqlalchemy import Integer, or_, label, case, and_
from fastapi import HTTPException
from typing import List, Dict, Union, Any
from sqlalchemy.sql import func
from sqlalchemy.orm import aliased
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution_config import AgentExecutionConfiguration
import pytz
from datetime import datetime
class KnowledgeHandler:
    """APM queries about how an organisation's agents use its knowledge bases."""

    def __init__(self, session: Session, organisation_id: int):
        self.session = session
        self.organisation_id = organisation_id

    def _ensure_knowledge_exists(self, knowledge_name: str) -> None:
        """Raise HTTP 404 unless this organisation has a knowledge entry named ``knowledge_name``."""
        is_knowledge_valid = self.session.query(Knowledges.id).filter_by(name=knowledge_name) \
            .filter(Knowledges.organisation_id == self.organisation_id).first()
        if not is_knowledge_valid:
            raise HTTPException(status_code=404, detail="Knowledge not found")

    def _resolve_timezone(self, agent_id: int):
        """Return the agent's configured ``user_timezone``, falling back to GMT.

        The fallback is deliberate: a missing config, the string 'None', or a zone name
        pytz does not recognise must not make the whole listing fail.
        """
        try:
            user_timezone = AgentConfiguration.get_agent_config_by_key_and_agent_id(
                session=self.session, key='user_timezone', agent_id=agent_id)
            if user_timezone and user_timezone.value != 'None':
                return pytz.timezone(user_timezone.value)
        except (AttributeError, pytz.UnknownTimeZoneError):
            pass  # fall through to the GMT default below
        return pytz.timezone('GMT')

    def _execution_model(self, agent_execution_id):
        """Return the ``model`` configured on an execution, or None when unset or 'None'."""
        model_query = self.session.query(AgentExecutionConfiguration).filter(
            AgentExecutionConfiguration.agent_execution_id == agent_execution_id,
            AgentExecutionConfiguration.key == 'model'
        ).first()
        if model_query and model_query.value != 'None':
            return model_query.value
        return None

    def get_knowledge_usage_by_name(self, knowledge_name: str) -> Dict[str, int]:
        """Count agents that picked ``knowledge_name`` and their 'Knowledge Search' tool calls.

        Returns:
            ``{'knowledge_unique_agents', 'knowledge_calls'}``, or ``{}`` if never picked.

        Raises:
            HTTPException: 404 if the organisation has no knowledge with this name.
        """
        self._ensure_knowledge_exists(knowledge_name)

        EventAlias = aliased(Event)

        knowledge_used_event = self.session.query(
            Event.event_property['knowledge_name'].label('knowledge_name'),
            func.count(Event.agent_id.distinct()).label('knowledge_unique_agents')
        ).filter(
            Event.event_name == 'knowledge_picked',
            Event.org_id == self.organisation_id,
            Event.event_property['knowledge_name'].astext == knowledge_name
        ).group_by(
            Event.event_property['knowledge_name']
        ).first()

        if knowledge_used_event is None:
            return {}

        # Knowledge Search calls made by any agent that picked this knowledge.
        picking_agents = self.session.query(Event.agent_id).filter(
            Event.event_name == 'knowledge_picked',
            Event.org_id == self.organisation_id,
            Event.event_property['knowledge_name'].astext == knowledge_name
        )
        knowledge_calls = self.session.query(
            EventAlias
        ).filter(
            EventAlias.event_property['tool_name'].astext == 'Knowledge Search',
            EventAlias.event_name == 'tool_used',
            EventAlias.org_id == self.organisation_id,
            EventAlias.agent_id.in_(picking_agents)
        ).count()

        return {
            'knowledge_unique_agents': knowledge_used_event.knowledge_unique_agents,
            'knowledge_calls': knowledge_calls
        }

    def get_knowledge_events_by_name(self, knowledge_name: str) -> List[Dict[str, Union[str, int, List[str]]]]:
        """List finished executions that picked ``knowledge_name``, newest first.

        One entry per execution (its first matching pick event). ``created_at`` is rendered
        in the agent's timezone as "%d %B %Y %H:%M". Ordering uses the real timestamp,
        because display strings from agents in different timezones are not comparable.

        Raises:
            HTTPException: 404 if the organisation has no knowledge with this name.
        """
        self._ensure_knowledge_exists(knowledge_name)

        knowledge_events = self.session.query(Event).filter(
            Event.org_id == self.organisation_id,
            Event.event_name == 'knowledge_picked',
            Event.event_property['knowledge_name'].astext == knowledge_name
        ).all()
        knowledge_events = [ke for ke in knowledge_events if 'agent_execution_id' in ke.event_property]

        event_runs = self.session.query(Event).filter(
            Event.org_id == self.organisation_id,
            or_(Event.event_name == 'run_completed', Event.event_name == 'run_iteration_limit_crossed')
        ).all()

        agent_created_events = self.session.query(Event).filter(
            Event.org_id == self.organisation_id,
            Event.event_name == 'agent_created'
        ).all()

        # Index the first event per key (same first-match semantics as a linear scan),
        # skipping run events that lack an execution id instead of raising KeyError.
        run_by_execution = {}
        for er in event_runs:
            if 'agent_execution_id' in er.event_property:
                run_by_execution.setdefault((er.agent_id, er.event_property['agent_execution_id']), er)
        created_by_agent = {}
        for ace in agent_created_events:
            created_by_agent.setdefault(ace.agent_id, ace)

        reported_executions = set()
        timed_results = []
        for knowledge_event in knowledge_events:
            agent_execution_id = knowledge_event.event_property['agent_execution_id']
            if agent_execution_id in reported_executions:
                continue

            event_run = run_by_execution.get((knowledge_event.agent_id, agent_execution_id))
            agent_created_event = created_by_agent.get(knowledge_event.agent_id)
            if not (event_run and agent_created_event):
                continue

            model_value = self._execution_model(agent_execution_id)
            tz = self._resolve_timezone(knowledge_event.agent_id)

            result_dict = {
                'agent_execution_id': agent_execution_id,
                'created_at': knowledge_event.created_at.astimezone(tz).strftime("%d %B %Y %H:%M"),
                'tokens_consumed': event_run.event_property['tokens_consumed'],
                'calls': event_run.event_property['calls'],
                'agent_execution_name': event_run.event_property['name'],
                'agent_name': agent_created_event.event_property['agent_name'],
                'model': model_value if model_value else agent_created_event.event_property['model']
            }
            reported_executions.add(agent_execution_id)
            timed_results.append((knowledge_event.created_at, result_dict))

        timed_results.sort(key=lambda item: item[0], reverse=True)
        return [result for _, result in timed_results]
================================================
FILE: superagi/apm/tools_handler.py
================================================
from typing import List, Dict, Union
from sqlalchemy import func, distinct, and_
from sqlalchemy.orm import Session
from sqlalchemy import Integer, String
from fastapi import HTTPException
from superagi.models.events import Event
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit
from sqlalchemy import or_
from sqlalchemy.sql import label
from datetime import datetime
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution_config import AgentExecutionConfiguration
import pytz
class ToolsHandler:
    """APM queries about tool usage by one organisation's agents."""

    def __init__(self, session: Session, organisation_id: int):
        self.session = session
        self.organisation_id = organisation_id

    def get_tool_and_toolkit(self):
        """Map every tool name (lower-cased) to the name of the toolkit that contains it."""
        tools_and_toolkits = self.session.query(
            func.lower(Tool.name).label('tool_name'), Toolkit.name.label('toolkit_name')).join(
            Toolkit, Tool.toolkit_id == Toolkit.id).all()

        return {item.tool_name.lower(): item.toolkit_name for item in tools_and_toolkits}

    def calculate_tool_usage(self) -> List[Dict[str, int]]:
        """Per-tool usage count and distinct-agent count, most used first."""
        tool_used_subquery = self.session.query(
            Event.event_property['tool_name'].label('tool_name'),
            Event.agent_id
        ).filter_by(event_name="tool_used", org_id=self.organisation_id).subquery()

        agent_count = self.session.query(
            tool_used_subquery.c.tool_name,
            func.count(func.distinct(tool_used_subquery.c.agent_id)).label('unique_agents')
        ).group_by(tool_used_subquery.c.tool_name).subquery()

        total_usage = self.session.query(
            tool_used_subquery.c.tool_name,
            func.count(tool_used_subquery.c.tool_name).label('total_usage')
        ).group_by(tool_used_subquery.c.tool_name).subquery()

        query = self.session.query(
            agent_count.c.tool_name,
            agent_count.c.unique_agents,
            total_usage.c.total_usage,
        ).join(total_usage, total_usage.c.tool_name == agent_count.c.tool_name)

        tool_and_toolkit = self.get_tool_and_toolkit()

        tool_usage = [{
            'tool_name': row.tool_name,
            'unique_agents': row.unique_agents,
            'total_usage': row.total_usage,
            # tool_name comes from event JSON and may be absent; avoid None.lower().
            'toolkit': tool_and_toolkit.get(row.tool_name.lower(), None) if row.tool_name else None
        } for row in query.all()]

        tool_usage.sort(key=lambda tool: tool['total_usage'], reverse=True)
        return tool_usage

    def _ensure_tool_exists(self, tool_name: str) -> None:
        """Raise HTTP 404 unless a tool named ``tool_name`` exists."""
        is_tool_name_valid = self.session.query(Tool).filter_by(name=tool_name).first()
        if not is_tool_name_valid:
            raise HTTPException(status_code=404, detail="Tool not found")

    def _resolve_timezone(self, agent_id: int):
        """Return the agent's configured ``user_timezone``, falling back to GMT.

        The fallback is deliberate: a missing config, the string 'None', or a zone name
        pytz does not recognise must not make the whole listing fail.
        """
        try:
            user_timezone = AgentConfiguration.get_agent_config_by_key_and_agent_id(
                session=self.session, key='user_timezone', agent_id=agent_id)
            if user_timezone and user_timezone.value != 'None':
                return pytz.timezone(user_timezone.value)
        except (AttributeError, pytz.UnknownTimeZoneError):
            pass  # fall through to the GMT default below
        return pytz.timezone('GMT')

    def _execution_model(self, agent_execution_id):
        """Return the ``model`` configured on an execution, or None when unset or 'None'."""
        model_query = self.session.query(AgentExecutionConfiguration).filter(
            AgentExecutionConfiguration.agent_execution_id == agent_execution_id,
            AgentExecutionConfiguration.key == 'model'
        ).first()
        if model_query and model_query.value != 'None':
            return model_query.value
        return None

    def get_tool_usage_by_name(self, tool_name: str) -> Dict[str, int]:
        """Return ``{'tool_calls', 'tool_unique_agents'}`` for ``tool_name`` (zeros if unused).

        Raises:
            HTTPException: 404 if no tool with this name exists.
        """
        self._ensure_tool_exists(tool_name)

        tool_name_event = self.session.query(
            Event.event_property['tool_name'].cast(String).label('tool_name'),
            func.count(Event.id).label('tool_calls'),
            func.count(distinct(Event.agent_id)).label('tool_unique_agents')
        ).filter(
            Event.event_name == 'tool_used',
            Event.org_id == self.organisation_id,
            Event.event_property['tool_name'].astext == tool_name
        ).group_by(
            Event.event_property['tool_name'].cast(String)
        ).first()

        if not tool_name_event:
            return {'tool_calls': 0, 'tool_unique_agents': 0}
        return {
            'tool_calls': tool_name_event.tool_calls,
            'tool_unique_agents': tool_name_event.tool_unique_agents
        }

    def get_tool_events_by_name(self, tool_name: str) -> List[Dict[str, Union[str, int, List[str]]]]:
        """List finished executions that used ``tool_name``, newest first.

        One entry per execution (its first matching tool event), including other tools the
        agent used between that event and the run-end event. ``created_at`` is rendered in
        the agent's timezone as "%d %B %Y %H:%M". Ordering uses the real timestamp, because
        display strings from agents in different timezones are not comparable.

        Raises:
            HTTPException: 404 if no tool with this name exists.
        """
        self._ensure_tool_exists(tool_name)

        tool_events = self.session.query(Event).filter(
            Event.org_id == self.organisation_id,
            Event.event_name == 'tool_used',
            Event.event_property['tool_name'].astext == tool_name
        ).all()
        tool_events = [te for te in tool_events if 'agent_execution_id' in te.event_property]

        event_runs = self.session.query(Event).filter(
            Event.org_id == self.organisation_id,
            or_(Event.event_name == 'run_completed', Event.event_name == 'run_iteration_limit_crossed')
        ).all()

        agent_created_events = self.session.query(Event).filter(
            Event.org_id == self.organisation_id,
            Event.event_name == 'agent_created'
        ).all()

        # Index the first event per key (same first-match semantics as a linear scan),
        # skipping run events that lack an execution id instead of raising KeyError.
        run_by_execution = {}
        for er in event_runs:
            if 'agent_execution_id' in er.event_property:
                run_by_execution.setdefault((er.agent_id, er.event_property['agent_execution_id']), er)
        created_by_agent = {}
        for ace in agent_created_events:
            created_by_agent.setdefault(ace.agent_id, ace)

        reported_executions = set()
        timed_results = []
        for tool_event in tool_events:
            agent_execution_id = tool_event.event_property['agent_execution_id']
            if agent_execution_id in reported_executions:
                continue

            event_run = run_by_execution.get((tool_event.agent_id, agent_execution_id))
            agent_created_event = created_by_agent.get(tool_event.agent_id)
            if not (event_run and agent_created_event):
                continue

            model_value = self._execution_model(agent_execution_id)
            tz = self._resolve_timezone(tool_event.agent_id)

            # Other tools used by the same agent between this event and the run's end event.
            other_tools_events = self.session.query(
                Event
            ).filter(
                Event.org_id == self.organisation_id,
                Event.event_name == 'tool_used',
                Event.event_property['tool_name'].astext != tool_name,
                Event.agent_id == tool_event.agent_id,
                Event.id.between(tool_event.id, event_run.id)
            ).all()
            other_tools = [ote.event_property['tool_name'] for ote in other_tools_events]

            result_dict = {
                'created_at': tool_event.created_at.astimezone(tz).strftime("%d %B %Y %H:%M"),
                'agent_execution_id': agent_execution_id,
                'tokens_consumed': event_run.event_property['tokens_consumed'],
                'calls': event_run.event_property['calls'],
                'agent_execution_name': event_run.event_property['name'],
                'other_tools': other_tools,
                'agent_name': agent_created_event.event_property['agent_name'],
                'model': model_value if model_value else agent_created_event.event_property['model']
            }
            reported_executions.add(agent_execution_id)
            timed_results.append((tool_event.created_at, result_dict))

        timed_results.sort(key=lambda item: item[0], reverse=True)
        return [result for _, result in timed_results]
================================================
FILE: superagi/config/__init__.py
================================================
================================================
FILE: superagi/config/config.py
================================================
import os
from pydantic import BaseSettings
from pathlib import Path
import yaml
from superagi.lib.logger import logger
# Name of the YAML settings file, resolved relative to the repository root below.
CONFIG_FILE: str = "config.yaml"
class Config(BaseSettings):
    """Application settings: YAML file values overlaid with the process environment."""

    class Config:
        env_file_encoding = "utf-8"
        extra = "allow"  # Allow extra fields

    @classmethod
    def load_config(cls, config_file: str) -> dict:
        """Read ``config_file`` (creating an empty one if missing) and merge ``os.environ``.

        Environment variables override keys of the same name from the file.

        Returns:
            The merged settings mapping.
        """
        # If config file exists, read it
        if os.path.exists(config_file):
            with open(config_file, "r") as file:
                config_data = yaml.safe_load(file)
            # An empty YAML document loads as None.
            if config_data is None:
                config_data = {}
        else:
            # If config file doesn't exist, create an empty one so it can be filled in later.
            logger.info("\033[91m\033[1m"
                        + "\nConfig file not found. Enter required keys and values."
                        + "\033[0m\033[0m")
            config_data = {}
            with open(config_file, "w") as file:
                yaml.dump(config_data, file, default_flow_style=False)

        # Merge environment variables and config data
        env_vars = dict(os.environ)
        config_data = {**config_data, **env_vars}

        return config_data

    def __init__(self, config_file: str, **kwargs):
        """Load ``config_file`` + environment; explicit ``kwargs`` take precedence."""
        config_data = self.load_config(config_file)
        # Merge into one mapping: passing ``**config_data, **kwargs`` separately raises
        # "got multiple values for keyword argument" when a kwarg repeats a file/env key.
        super().__init__(**{**config_data, **kwargs})

    def get_config(self, key: str, default: str = None) -> str:
        """Return the setting ``key``, or ``default`` when it is not present."""
        return self.dict().get(key, default)
# Parent of the ``superagi`` package directory (this module is superagi/config/config.py).
ROOT_DIR = os.path.dirname(str(Path(__file__).parent.parent))
# Process-wide settings loaded once at import time from <ROOT_DIR>/<CONFIG_FILE>.
_config_instance = Config(f"{ROOT_DIR}/{CONFIG_FILE}")
def get_config(key: str, default: str = None) -> str:
    """Look up ``key`` in the module-level settings, returning ``default`` if absent."""
    settings = _config_instance
    return settings.get_config(key, default)
================================================
FILE: superagi/controllers/__init__.py
================================================
================================================
FILE: superagi/controllers/agent.py
================================================
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from sqlalchemy import desc
import ast
from pytz import timezone
from sqlalchemy import func, or_
from superagi.models.agent import Agent
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_schedule import AgentSchedule
from superagi.models.agent_template import AgentTemplate
from superagi.models.project import Project
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.agent_execution import AgentExecution
from superagi.models.tool import Tool
from superagi.controllers.types.agent_schedule import AgentScheduleInput
from superagi.controllers.types.agent_with_config import AgentConfigInput
from superagi.controllers.types.agent_with_config_schedule import AgentConfigSchedule
from jsonmerge import merge
from datetime import datetime
import json
from superagi.models.toolkit import Toolkit
from superagi.models.knowledges import Knowledges
from sqlalchemy import func
# from superagi.types.db import AgentOut, AgentIn
from superagi.helper.auth import check_auth
from superagi.apm.event_handler import EventHandler
from superagi.models.workflows.iteration_workflow import IterationWorkflow
# FastAPI router collecting the agent endpoints defined in this module.
router: APIRouter = APIRouter()
class AgentOut(BaseModel):
    """Response schema for an agent row."""
    id: int
    name: str
    project_id: int
    description: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow building this model directly from ORM objects (pydantic v1 option).
        orm_mode = True
class AgentIn(BaseModel):
    """Input schema with the basic fields of an agent."""
    name: str
    project_id: int
    description: str

    class Config:
        # Allow building this model directly from ORM objects (pydantic v1 option).
        orm_mode = True
@router.post("/create", status_code=201)
def create_agent_with_config(agent_with_config: AgentConfigInput,
                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent with configurations, plus an initial execution in CREATED state.

    Args:
        agent_with_config (AgentConfigInput): Data for creating a new agent with configurations.
            - name (str): Name of the agent.
            - project_id (int): Identifier of the associated project.
            - description (str): Description of the agent.
            - goal (List[str]): List of goals for the agent.
            - instruction: Instructions for the agent (copied into the execution config).
            - constraints (List[str]): List of constraints for the agent.
            - toolkits (List[int]): Toolkit identifiers; their tool ids are appended to ``tools``.
            - tools (List[int]): List of tool identifiers associated with the agent.
            - exit (str): Exit condition for the agent.
            - iteration_interval (int): Interval between iterations for the agent.
            - model (str): Model information for the agent.
            - permission_type (str): Permission type for the agent.
            - LTM_DB (str): LTM database for the agent.
            - max_iterations (int): Maximum number of iterations for the agent.
            - user_timezone (string): Timezone of the user.
            - knowledge: Identifier of the knowledge picked for the agent (falsy when none).
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        dict: Dictionary containing the created agent's ID, execution ID, name, and content type.

    Raises:
        HTTPException (status_code=404): If the associated project or any of the tools is not found.
    """
    project = db.session.query(Project).get(agent_with_config.project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    invalid_tools = Tool.get_invalid_tools(agent_with_config.tools, db.session)
    if len(invalid_tools) > 0:  # a non-empty result lists tool ids that do not exist
        raise HTTPException(status_code=404,
                            detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.")

    agent_toolkit_tools = Toolkit.fetch_tool_ids_from_toolkit(session=db.session,
                                                              toolkit_ids=agent_with_config.toolkits)
    agent_with_config.tools.extend(agent_toolkit_tools)
    db_agent = Agent.create_agent_with_config(db, agent_with_config)

    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
    iteration_step_id = IterationWorkflow.fetch_trigger_step_id(
        db.session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1

    # The first execution is created in CREATED state; this endpoint does not dispatch it to a worker.
    execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
                               name="New Run", current_agent_step_id=start_step.id,
                               iteration_workflow_step_id=iteration_step_id)
    agent_execution_configs = {
        "goal": agent_with_config.goal,
        "instruction": agent_with_config.instruction,
        "constraints": agent_with_config.constraints,
        "toolkits": agent_with_config.toolkits,
        "exit": agent_with_config.exit,
        "tools": agent_with_config.tools,
        "iteration_interval": agent_with_config.iteration_interval,
        "model": agent_with_config.model,
        "permission_type": agent_with_config.permission_type,
        "LTM_DB": agent_with_config.LTM_DB,
        "max_iterations": agent_with_config.max_iterations,
        "user_timezone": agent_with_config.user_timezone,
        "knowledge": agent_with_config.knowledge
    }
    db.session.add(execution)
    # Commit so execution.id exists for the config rows and events below
    # (a flush after a commit has nothing left to flush, so none is issued).
    db.session.commit()
    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
                                                                     agent_execution_configs=agent_execution_configs)

    agent = db.session.query(Agent).filter(Agent.id == db_agent.id).first()
    organisation = agent.get_agent_organisation(db.session)
    organisation_id = organisation.id if organisation else 0

    EventHandler(session=db.session).create_event('run_created',
                                                  {'agent_execution_id': execution.id,
                                                   'agent_execution_name': execution.name},
                                                  db_agent.id,
                                                  organisation_id)
    if agent_with_config.knowledge:
        knowledge_row = db.session.query(Knowledges.name).filter(
            Knowledges.id == agent_with_config.knowledge).first()
        # The agent and execution are already committed at this point; an unknown knowledge id
        # must not turn the request into a 500 (``.first()[0]`` on None raised TypeError).
        if knowledge_row is not None:
            EventHandler(session=db.session).create_event('knowledge_picked',
                                                          {'knowledge_name': knowledge_row[0],
                                                           'agent_execution_id': execution.id},
                                                          db_agent.id,
                                                          organisation_id)
    EventHandler(session=db.session).create_event('agent_created',
                                                  {'agent_name': agent_with_config.name,
                                                   'model': agent_with_config.model},
                                                  db_agent.id,
                                                  organisation_id)
    db.session.commit()
    return {
        "id": db_agent.id,
        "execution_id": execution.id,
        "name": db_agent.name,
        "contentType": "Agents"
    }
@router.post("/schedule", status_code=201)
def create_and_schedule_agent(agent_config_schedule: AgentConfigSchedule,
                              Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent with configurations and register a schedule for it.

    Args:
        agent_config_schedule (AgentConfigSchedule): Wrapper holding the agent configuration
            (``agent_config``) and the schedule details (``schedule``: start_time,
            recurrence_interval, expiry_date, expiry_runs).
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        dict: Dictionary containing the created agent's ID, name, content type and schedule ID of the agent.

    Raises:
        HTTPException (status_code=404): If the associated project or any of the tools is not found.
        HTTPException (status_code=500): If the schedule row was not assigned an ID after commit.
    """
    agent_config = agent_config_schedule.agent_config
    project = db.session.query(Project).get(agent_config.project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    invalid_tools = Tool.get_invalid_tools(agent_config.tools, db.session)
    if len(invalid_tools) > 0:  # a non-empty result lists tool ids that do not exist
        raise HTTPException(status_code=404,
                            detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.")

    agent_toolkit_tools = Toolkit.fetch_tool_ids_from_toolkit(session=db.session,
                                                              toolkit_ids=agent_config.toolkits)
    agent_config.tools.extend(agent_toolkit_tools)
    db_agent = Agent.create_agent_with_config(db, agent_config)

    # Keep the request payload and the ORM row in separate names (previously one name was reused for both).
    schedule_input = agent_config_schedule.schedule
    agent_schedule = AgentSchedule(
        agent_id=db_agent.id,
        start_time=schedule_input.start_time,
        next_scheduled_time=schedule_input.start_time,  # first run happens at the start time
        recurrence_interval=schedule_input.recurrence_interval,
        expiry_date=schedule_input.expiry_date,
        expiry_runs=schedule_input.expiry_runs,
        current_runs=0,
        status="SCHEDULED"
    )
    db.session.add(agent_schedule)
    db.session.commit()
    if agent_schedule.id is None:
        raise HTTPException(status_code=500, detail="Failed to schedule agent")

    agent = db.session.query(Agent).filter(Agent.id == db_agent.id).first()
    organisation = agent.get_agent_organisation(db.session)
    EventHandler(session=db.session).create_event('agent_created',
                                                  {'agent_name': agent_config.name,
                                                   'model': agent_config.model},
                                                  db_agent.id,
                                                  organisation.id if organisation else 0)
    db.session.commit()
    return {
        "id": db_agent.id,
        "name": db_agent.name,
        "contentType": "Agents",
        "schedule_id": agent_schedule.id
    }
@router.post("/stop/schedule", status_code=200)
def stop_schedule(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Stop the active schedule of an agent.

    The first schedule row of the agent whose status is SCHEDULED is marked STOPPED;
    the row itself is kept.

    Args:
        agent_id (int): Identifier of the Agent.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no SCHEDULED schedule.
    """
    active_schedule = (
        db.session.query(AgentSchedule)
        .filter(AgentSchedule.agent_id == agent_id,
                AgentSchedule.status == "SCHEDULED")
        .first()
    )
    if not active_schedule:
        raise HTTPException(status_code=404, detail="Schedule not found")

    active_schedule.status = "STOPPED"
    db.session.commit()
@router.put("/edit/schedule", status_code=200)
def edit_schedule(schedule: AgentScheduleInput,
                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Overwrite the timing of an agent's active schedule.

    Args:
        schedule (AgentScheduleInput): New schedule data; ``schedule.agent_id`` selects the agent.
            The next run is reset to the new start time.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no SCHEDULED schedule.
    """
    active_schedule = (
        db.session.query(AgentSchedule)
        .filter(AgentSchedule.agent_id == schedule.agent_id,
                AgentSchedule.status == "SCHEDULED")
        .first()
    )
    if not active_schedule:
        raise HTTPException(status_code=404, detail="Schedule not found")

    new_start = schedule.start_time
    active_schedule.start_time = new_start
    active_schedule.next_scheduled_time = new_start
    active_schedule.recurrence_interval = schedule.recurrence_interval
    active_schedule.expiry_date = schedule.expiry_date
    active_schedule.expiry_runs = schedule.expiry_runs
    db.session.commit()
@router.get("/get/schedule_data/{agent_id}")
def get_schedule_data(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Get the scheduling data for a given agent, rendered in the agent's configured timezone.

    The timezone comes from the agent's ``user_timezone`` configuration. GMT is used when that
    configuration is missing, is the string "None", or is not a timezone name pytz recognises.

    Args:
        agent_id (int): Identifier of the Agent.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Raises:
        HTTPException (status_code=404): If the agent has no SCHEDULED schedule.

    Returns:
        dict with:
            current_datetime (str): Now, formatted "%d/%m/%Y %I:%M %p".
            start_date (str): Schedule start date, formatted "%d %b %Y".
            start_time (str): Schedule start time, formatted "%I:%M %p".
            recurrence_interval (str | None): Time interval for recurring schedule runs.
            expiry_date (str | None): Date ("%d/%m/%Y") after which the schedule stops runs.
            expiry_runs (int | None): Number of runs before expiry; None when stored as -1.
    """
    agent = db.session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id,
                                                   AgentSchedule.status == "SCHEDULED").first()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent Schedule not found")

    user_timezone = db.session.query(AgentConfiguration).filter(AgentConfiguration.key == "user_timezone",
                                                                AgentConfiguration.agent_id == agent_id).first()
    tzone = timezone('GMT')
    if user_timezone and user_timezone.value != "None":
        try:
            tzone = timezone(user_timezone.value)
        except KeyError:
            # pytz.UnknownTimeZoneError subclasses KeyError: an unrecognised stored value
            # falls back to GMT instead of failing the whole request with a 500.
            pass

    current_datetime = datetime.now(tzone).strftime("%d/%m/%Y %I:%M %p")
    start_local = agent.start_time.astimezone(tzone)
    return {
        "current_datetime": current_datetime,
        "start_date": start_local.strftime("%d %b %Y"),
        "start_time": start_local.strftime("%I:%M %p"),
        "recurrence_interval": agent.recurrence_interval if agent.recurrence_interval else None,
        "expiry_date": agent.expiry_date.astimezone(tzone).strftime("%d/%m/%Y") if agent.expiry_date else None,
        "expiry_runs": agent.expiry_runs if agent.expiry_runs != -1 else None
    }
@router.get("/get/project/{project_id}")
def get_agents_by_project_id(project_id: int,
                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Get all non-deleted agents of a project, annotated with run/schedule flags.

    An agent counts as non-deleted when ``is_deleted`` is False or NULL.

    Args:
        project_id (int): Identifier of the project.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        list: One dict per agent: the agent's loaded attributes plus ``is_running`` (any execution
        in RUNNING status) and ``is_scheduled`` (a SCHEDULED schedule exists). Running agents come
        first; otherwise the query order is preserved (stable sort).

    Raises:
        HTTPException (status_code=404): If the project is not found.
    """
    project = db.session.query(Project).get(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # ``Agent.is_deleted is None`` is a Python identity test that is always False, not SQL
    # ``IS NULL``; ``.is_(None)`` emits the intended NULL check.
    agents = db.session.query(Agent).filter(
        Agent.project_id == project_id,
        or_(Agent.is_deleted == False, Agent.is_deleted.is_(None))  # noqa: E712 (SQL expression)
    ).all()

    agents_with_status = []
    for agent in agents:
        # Existence checks instead of loading every execution of the agent.
        is_running = db.session.query(AgentExecution.id).filter(
            AgentExecution.agent_id == agent.id,
            AgentExecution.status == "RUNNING").first() is not None
        is_scheduled = db.session.query(AgentSchedule.id).filter(
            AgentSchedule.agent_id == agent.id,
            AgentSchedule.status == "SCHEDULED").first() is not None
        agents_with_status.append({
            **vars(agent),
            'is_running': is_running,
            'is_scheduled': is_scheduled
        })

    return sorted(agents_with_status, key=lambda item: item['is_running'], reverse=True)
@router.put("/delete/{agent_id}", status_code=200)
def delete_agent(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Soft-delete an existing Agent.

    - Sets the agent's ``is_deleted`` flag (soft delete; the row is kept).
    - Sets the status of every execution of the agent to "TERMINATED", whatever its current status.
    - Sets the agent's active (SCHEDULED) schedule, if any, to "STOPPED".

    Args:
        agent_id (int): Identifier of the Agent to delete.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        None (the endpoint responds with an empty/null body and status 200).

    Raises:
        HTTPException (Status Code=404): If the Agent is not found or is already deleted.
    """
    db_agent = db.session.query(Agent).filter(Agent.id == agent_id).first()
    # Check existence first so no further queries run for a missing/deleted agent.
    if not db_agent or db_agent.is_deleted:
        raise HTTPException(status_code=404, detail="agent not found")

    db_agent.is_deleted = True

    db_agent_executions = db.session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id).all()
    for db_agent_execution in db_agent_executions:
        db_agent_execution.status = "TERMINATED"

    db_agent_schedule = db.session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id,
                                                               AgentSchedule.status == "SCHEDULED").first()
    if db_agent_schedule:
        db_agent_schedule.status = "STOPPED"

    db.session.commit()
================================================
FILE: superagi/controllers/agent_execution.py
================================================
from datetime import datetime
from typing import Optional, Union, List
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from pydantic import BaseModel
from pydantic.fields import List
from superagi.controllers.types.agent_execution_config import AgentRunIn
from superagi.helper.time_helper import get_time_difference
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.agent_schedule import AgentSchedule
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.worker import execute_agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent import Agent
from superagi.models.models import Models
from fastapi import APIRouter
from sqlalchemy import desc
from superagi.helper.auth import check_auth
from superagi.controllers.types.agent_schedule import AgentScheduleInput
from superagi.apm.event_handler import EventHandler
from superagi.controllers.tool import ToolOut
from superagi.models.agent_config import AgentConfiguration
from superagi.models.knowledges import Knowledges
# Router collecting the agent-execution (run) endpoints below.
# NOTE(review): presumably included into the FastAPI app elsewhere; the URL prefix is not visible here.
router = APIRouter()
class AgentExecutionOut(BaseModel):
    """Response schema for an agent execution (run).

    Used as ``response_model`` by the add/get/update endpoints below; ``orm_mode`` lets pydantic
    read the fields from an ``AgentExecution`` ORM object.
    """
    id: int
    status: str
    name: str
    agent_id: int
    last_execution_time: datetime
    num_of_calls: int
    num_of_tokens: int
    current_agent_step_id: int
    permission_id: Optional[int]
    created_at: datetime
    updated_at: datetime

    class Config:
        orm_mode = True
class AgentExecutionIn(BaseModel):
    """Request schema for creating or updating an agent execution; every field is optional.

    ``goal`` and ``instruction`` are copied into the new execution's configuration by the
    create endpoint. The update endpoint reads only ``agent_id`` and ``status``.
    """
    status: Optional[str]
    name: Optional[str]
    agent_id: Optional[int]
    last_execution_time: Optional[datetime]
    num_of_calls: Optional[int]
    num_of_tokens: Optional[int]
    current_agent_step_id: Optional[int]
    permission_id: Optional[int]
    goal: Optional[List[str]]
    instruction: Optional[List[str]]

    # Pydantic only honours an inner class named exactly ``Config``; the previous lowercase
    # ``config`` was silently ignored, so ``orm_mode`` never took effect.
    class Config:
        orm_mode = True
# CRUD Operations
@router.post("/add", response_model=AgentExecutionOut, status_code=201)
def create_agent_execution(agent_execution: AgentExecutionIn,
                           Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new agent execution/run, mark it RUNNING and dispatch it to the worker.

    ``goal`` and ``instruction`` come from the request. Every other execution config key is
    copied from the agent's stored ``AgentConfiguration`` rows. ``toolkits`` is parsed into a
    list of ints; empty ``toolkits``/``constraints`` become ``[]``.

    Args:
        agent_execution (AgentExecutionIn): The agent execution data (``agent_id``, ``name``,
            ``goal``, ``instruction``).
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        AgentExecution: The created agent execution.

    Raises:
        HTTPException (Status Code=404): If the agent is not found or is deleted.
    """
    agent = db.session.query(Agent).filter(Agent.id == agent_execution.agent_id, Agent.is_deleted == False).first()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")

    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id)
    iteration_step_id = IterationWorkflow.fetch_trigger_step_id(
        db.session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1

    db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(),
                                        agent_id=agent_execution.agent_id, name=agent_execution.name,
                                        num_of_calls=0, num_of_tokens=0,
                                        current_agent_step_id=start_step.id,
                                        iteration_workflow_step_id=iteration_step_id)

    agent_execution_configs = {
        "goal": agent_execution.goal,
        "instruction": agent_execution.instruction
    }
    agent_configs = db.session.query(AgentConfiguration).filter(
        AgentConfiguration.agent_id == agent_execution.agent_id).all()
    keys_to_exclude = ["goal", "instruction"]  # taken from the request instead of the agent
    for agent_config in agent_configs:
        if agent_config.key in keys_to_exclude:
            continue
        if agent_config.key == "toolkits":
            if agent_config.value:
                # Stored as a brace-delimited, comma-separated string of ids.
                toolkits = [int(item) for item in agent_config.value.strip('{}').split(',')
                            if item.strip() and item != '[]']
                agent_execution_configs[agent_config.key] = toolkits
            else:
                agent_execution_configs[agent_config.key] = []
        elif agent_config.key == "constraints":
            agent_execution_configs[agent_config.key] = agent_config.value if agent_config.value else []
        else:
            agent_execution_configs[agent_config.key] = agent_config.value

    db.session.add(db_agent_execution)
    db.session.commit()  # assigns db_agent_execution.id
    # Update status from CREATED to RUNNING once the row exists.
    db_agent_execution.status = "RUNNING"
    db.session.commit()
    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution,
                                                                     agent_execution_configs=agent_execution_configs)

    organisation = agent.get_agent_organisation(db.session)
    organisation_id = organisation.id if organisation else 0
    agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(
        session=db.session, key='knowledge', agent_id=agent_execution.agent_id)
    EventHandler(session=db.session).create_event('run_created',
                                                  {'agent_execution_id': db_agent_execution.id,
                                                   'agent_execution_name': db_agent_execution.name},
                                                  agent_execution.agent_id,
                                                  organisation_id)
    if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
        # Check the knowledge object itself; calling ``.name`` on a missing record raised AttributeError.
        knowledge = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value))
        if knowledge is not None and knowledge.name is not None:
            EventHandler(session=db.session).create_event('knowledge_picked',
                                                          {'knowledge_name': knowledge.name,
                                                           'agent_execution_id': db_agent_execution.id},
                                                          agent_execution.agent_id,
                                                          organisation_id)

    # NOTE(review): the return value is discarded; kept for whatever side effect/validation it
    # performs, but only when an organisation exists (``organisation.id`` on None used to crash).
    if organisation is not None:
        Models.api_key_from_configurations(session=db.session, organisation_id=organisation.id)

    if db_agent_execution.status == "RUNNING":
        execute_agent.delay(db_agent_execution.id, datetime.now())
    return db_agent_execution
@router.post("/add_run", status_code=201)
def create_agent_run(agent_execution: AgentRunIn, Authorize: AuthJWT = Depends(check_auth)):
    """
    Create and start a new run with all of its information (goals, instructions, model, etc.).

    The agent's stored configuration is first overwritten with the run's details. A new
    execution is then created, moved to RUNNING and dispatched to the worker.

    Args:
        agent_execution (AgentRunIn): The run data, including ``agent_id``, ``name`` and every
            execution config value.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        AgentExecution: The created agent execution.

    Raises:
        HTTPException (Status Code=404): If the agent is not found or is deleted.
    """
    agent = db.session.query(Agent).filter(Agent.id == agent_execution.agent_id, Agent.is_deleted == False).first()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")

    # Update the agent configurations table with the data of the latest agent execution.
    AgentConfiguration.update_agent_configurations_table(session=db.session, agent_id=agent_execution.agent_id,
                                                         updated_details=agent_execution)

    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id)
    iteration_step_id = IterationWorkflow.fetch_trigger_step_id(
        db.session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1

    db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(),
                                        agent_id=agent_execution.agent_id, name=agent_execution.name,
                                        num_of_calls=0, num_of_tokens=0,
                                        current_agent_step_id=start_step.id,
                                        iteration_workflow_step_id=iteration_step_id)
    agent_execution_configs = {
        "goal": agent_execution.goal,
        "instruction": agent_execution.instruction,
        "constraints": agent_execution.constraints,
        "toolkits": agent_execution.toolkits,
        "exit": agent_execution.exit,
        "tools": agent_execution.tools,
        "iteration_interval": agent_execution.iteration_interval,
        "model": agent_execution.model,
        "permission_type": agent_execution.permission_type,
        "LTM_DB": agent_execution.LTM_DB,
        "max_iterations": agent_execution.max_iterations,
        "user_timezone": agent_execution.user_timezone,
        "knowledge": agent_execution.knowledge
    }
    db.session.add(db_agent_execution)
    db.session.commit()  # assigns db_agent_execution.id
    # Update status from CREATED to RUNNING once the row exists.
    db_agent_execution.status = "RUNNING"
    db.session.commit()
    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution,
                                                                     agent_execution_configs=agent_execution_configs)

    organisation = agent.get_agent_organisation(db.session)
    organisation_id = organisation.id if organisation else 0
    EventHandler(session=db.session).create_event('run_created',
                                                  {'agent_execution_id': db_agent_execution.id,
                                                   'agent_execution_name': db_agent_execution.name},
                                                  agent_execution.agent_id,
                                                  organisation_id)
    agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(
        session=db.session, key='knowledge', agent_id=agent_execution.agent_id)
    if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
        # Check the knowledge object itself; calling ``.name`` on a missing record raised AttributeError.
        knowledge = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value))
        if knowledge is not None and knowledge.name is not None:
            EventHandler(session=db.session).create_event('knowledge_picked',
                                                          {'knowledge_name': knowledge.name,
                                                           'agent_execution_id': db_agent_execution.id},
                                                          agent_execution.agent_id,
                                                          organisation_id)

    if db_agent_execution.status == "RUNNING":
        execute_agent.delay(db_agent_execution.id, datetime.now())
    return db_agent_execution
@router.post("/schedule", status_code=201)
def schedule_existing_agent(agent_schedule: AgentScheduleInput,
                            Authorize: AuthJWT = Depends(check_auth)):
    """
    Create or replace the active schedule of an already existing agent.

    If the agent already has a SCHEDULED schedule, its timing fields are overwritten in place.
    Otherwise a new SCHEDULED schedule with zero completed runs is inserted. In both cases the
    next run is set to the given start time.

    Args:
        agent_schedule (AgentScheduleInput): Schedule data:
            agent_id (Integer): The ID of the agent being scheduled.
            start_time (DateTime): The date and time from which the agent is scheduled.
            recurrence_interval (String): "none" if not recurring, or an interval such as
                '2 Weeks', '1 Month', '2 Minutes'.
            expiry_date (DateTime): The date and time when the schedule stops runs.
            expiry_runs (Integer): The number of runs before the schedule expires.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        dict: ``{"schedule_id": <id>}`` of the created or updated schedule.

    Raises:
        HTTPException (Status Code=500): If the schedule row has no ID after commit.
    """
    timing = {
        "start_time": agent_schedule.start_time,
        "next_scheduled_time": agent_schedule.start_time,
        "recurrence_interval": agent_schedule.recurrence_interval,
        "expiry_date": agent_schedule.expiry_date,
        "expiry_runs": agent_schedule.expiry_runs,
    }

    scheduled_agent = (
        db.session.query(AgentSchedule)
        .filter(AgentSchedule.agent_id == agent_schedule.agent_id,
                AgentSchedule.status == "SCHEDULED")
        .first()
    )
    if scheduled_agent:
        for attribute, value in timing.items():
            setattr(scheduled_agent, attribute, value)
    else:
        scheduled_agent = AgentSchedule(
            agent_id=agent_schedule.agent_id,
            current_runs=0,
            status="SCHEDULED",
            **timing
        )
        db.session.add(scheduled_agent)
    db.session.commit()

    if scheduled_agent.id is None:
        raise HTTPException(status_code=500, detail="Failed to schedule agent")
    return {"schedule_id": scheduled_agent.id}
@router.get("/get/{agent_execution_id}", response_model=AgentExecutionOut)
def get_agent_execution(agent_execution_id: int,
                        Authorize: AuthJWT = Depends(check_auth)):
    """
    Fetch a single agent execution by its ID.

    Args:
        agent_execution_id (int): The ID of the agent execution.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        AgentExecution: The requested agent execution.

    Raises:
        HTTPException (Status Code=404): If no execution has that ID.
    """
    found = (
        db.session.query(AgentExecution)
        .filter(AgentExecution.id == agent_execution_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail="Agent execution not found")
    return found
@router.put("/update/{agent_execution_id}", response_model=AgentExecutionOut)
def update_agent_execution(agent_execution_id: int,
                           agent_execution: AgentExecutionIn,
                           Authorize: AuthJWT = Depends(check_auth)):
    """
    Update the agent and/or status of a particular agent execution.

    Only ``agent_id`` and ``status`` from the payload are applied, and ``last_execution_time`` is
    refreshed. When the resulting status is RUNNING, the execution is enqueued on the worker.

    Args:
        agent_execution_id (int): ID of the execution to update.
        agent_execution (AgentExecutionIn): Payload; ``status`` must be one of
            CREATED, RUNNING, PAUSED, TERMINATED (COMPLETED cannot be set through this endpoint).
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        AgentExecution: The updated execution.

    Raises:
        HTTPException (400): If ``status`` is COMPLETED, missing, or not an accepted value.
        HTTPException (404): If the execution, or the agent referenced by ``agent_id``, is not found.
    """
    db_agent_execution = db.session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
    if agent_execution.status == "COMPLETED":
        raise HTTPException(status_code=400, detail="Invalid Request")
    if not db_agent_execution:
        raise HTTPException(status_code=404, detail="Agent Execution not found")

    new_agent = None
    if agent_execution.agent_id:
        new_agent = db.session.query(Agent).get(agent_execution.agent_id)
        if not new_agent:
            raise HTTPException(status_code=404, detail="Agent not found")

    # COMPLETED is rejected above, so it is not listed here.
    if agent_execution.status not in ["CREATED", "RUNNING", "PAUSED", "TERMINATED"]:
        raise HTTPException(status_code=400, detail="Invalid Request")

    # Mutate the row only after every validation has passed.
    if new_agent is not None:
        db_agent_execution.agent_id = new_agent.id
    db_agent_execution.status = agent_execution.status
    db_agent_execution.last_execution_time = datetime.now()
    db.session.commit()

    if db_agent_execution.status == "RUNNING":
        execute_agent.delay(db_agent_execution.id, datetime.now())
    return db_agent_execution
@router.get("/get/agents/status/{status}")
def agent_list_by_status(status: str,
                         Authorize: AuthJWT = Depends(check_auth)):
    """
    List the distinct agents that have at least one execution in the given status.

    Args:
        status (str): Execution status; matched case-insensitively (upper-cased before querying).
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        list: The one-column result rows from the query.
        NOTE(review): elements are the row objects themselves, not bare ints; the response shape
        is kept as-is for existing callers.
    """
    status_rows = (
        db.session.query(AgentExecution.agent_id)
        .filter(AgentExecution.status == status.upper())
        .distinct()
        .all()
    )
    return list(status_rows)
@router.get("/get/agent/{agent_id}")
def list_running_agents(agent_id: str,
                        Authorize: AuthJWT = Depends(check_auth)):
    """
    List every execution of an agent, RUNNING ones first, then most recent first.

    Each returned execution gets a transient ``time_difference`` attribute describing how long
    ago ``last_execution_time`` was, computed by ``get_time_difference``.

    Args:
        agent_id (str): Identifier of the agent.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        list: The agent's executions.
    """
    running_first = desc(AgentExecution.status == 'RUNNING')
    newest_first = desc(AgentExecution.last_execution_time)
    executions = (
        db.session.query(AgentExecution)
        .filter(AgentExecution.agent_id == agent_id)
        .order_by(running_first, newest_first)
        .all()
    )
    for run in executions:
        run.time_difference = get_time_difference(run.last_execution_time, str(datetime.now()))
    return executions
@router.get("/get/latest/agent/project/{project_id}")
def get_agent_by_latest_execution(project_id: int,
                                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Get the agent whose execution ran most recently in a project.

    Only non-deleted agents are considered; executions are ordered by ``last_execution_time``.

    Args:
        project_id (int): Identifier of the project.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        dict: Agent details plus ``status`` (True when that latest execution is RUNNING)
        and ``contentType`` "Agents".

    Raises:
        HTTPException (Status Code=404): If the project has no executions of non-deleted agents.
    """
    latest_execution = (
        db.session.query(AgentExecution)
        .join(Agent, AgentExecution.agent_id == Agent.id)
        .filter(Agent.project_id == project_id, Agent.is_deleted == False)
        .order_by(desc(AgentExecution.last_execution_time))
        .first()
    )
    # Previously a project without executions crashed with AttributeError on None (HTTP 500).
    if latest_execution is None:
        raise HTTPException(status_code=404, detail="No agent execution found for this project")

    is_running = latest_execution.status == "RUNNING"
    agent = db.session.query(Agent).filter(Agent.id == latest_execution.agent_id).first()
    return {
        "agent_id": latest_execution.agent_id,
        "project_id": project_id,
        "created_at": agent.created_at,
        "description": agent.description,
        "updated_at": agent.updated_at,
        "name": agent.name,
        "id": agent.id,
        "status": is_running,
        "contentType": "Agents"
    }
================================================
FILE: superagi/controllers/agent_execution_config.py
================================================
import ast
import json
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from typing import Optional, Union
from sqlalchemy import func, or_
from sqlalchemy import desc
from superagi.helper.auth import check_auth
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.tool import Tool
from superagi.models.knowledges import Knowledges
# Router for the agent-execution configuration endpoint below.
# NOTE(review): presumably included into the FastAPI app elsewhere; the URL prefix is not visible here.
router = APIRouter()
@router.get("/details/agent_id/{agent_id}/agent_execution_id/{agent_execution_id}")
def get_agent_execution_configuration(agent_id : Union[int, None, str],
                                      agent_execution_id: Union[int, None, str],
                                      Authorize: AuthJWT = Depends(check_auth)):
    """
    Get the configuration of an agent, merged with that of one of its executions.

    Path values that cannot be parsed as int arrive as str and are rejected.
    ``agent_execution_id == -1`` means "the agent's most recently created execution". If the
    agent has none, the scheduled-agent configuration (agent config only) is built instead.

    Args:
        agent_id (int): Identifier of the agent.
        agent_execution_id (int): Identifier of the agent execution, or -1 for the latest one.
        Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).

    Returns:
        dict: Agent configuration including its details and the agent's summed calls/tokens.

    Raises:
        HTTPException (status_code=404): If either id is undefined (non-numeric), the agent is not
            found or deleted, the execution is not found, or the execution belongs to another agent.
    """
    if isinstance(agent_id, str):
        raise HTTPException(status_code = 404, detail = "Agent Id undefined")
    if isinstance(agent_execution_id, str):
        raise HTTPException(status_code = 404, detail = "Agent Execution Id undefined")

    agent = (
        db.session.query(Agent)
        .filter(Agent.id == agent_id, Agent.is_deleted == False)
        .first()
    )
    if not agent:
        raise HTTPException(status_code = 404, detail = "Agent not found")

    # -1 is resolved to the most recently created execution of this agent, when one exists.
    if agent_execution_id == -1:
        newest = (
            db.session.query(AgentExecution)
            .filter(AgentExecution.agent_id == agent_id)
            .order_by(desc(AgentExecution.created_at))
            .first()
        )
        if newest:
            agent_execution_id = newest.id

    has_execution = agent_execution_id != -1
    if has_execution:
        # Verify the execution exists and belongs to the requested agent.
        execution = AgentExecution.get_agent_execution_from_id(db.session, agent_execution_id)
        if execution is None:
            raise HTTPException(status_code = 404, detail = "Agent Execution not found")
        if execution.agent_id != agent_id:
            raise HTTPException(status_code = 404, detail = "Wrong agent id")

    results_agent = db.session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all()
    results_agent_execution = None
    if has_execution:
        results_agent_execution = db.session.query(AgentExecutionConfiguration).filter(
            AgentExecutionConfiguration.agent_execution_id == agent_execution_id).all()
    total_calls = db.session.query(func.sum(AgentExecution.num_of_calls)).filter(
        AgentExecution.agent_id == agent_id).scalar()
    total_tokens = db.session.query(func.sum(AgentExecution.num_of_tokens)).filter(
        AgentExecution.agent_id == agent_id).scalar()

    if has_execution:
        response = AgentExecutionConfiguration.build_agent_execution_config(
            db.session, agent, results_agent, results_agent_execution, total_calls, total_tokens)
    else:
        response = AgentExecutionConfiguration.build_scheduled_agent_execution_config(
            db.session, agent, results_agent, total_calls, total_tokens)

    db.session.close()
    return response
================================================
FILE: superagi/controllers/agent_execution_feed.py
================================================
import asyncio
from datetime import datetime
import time
from typing import Optional
from fastapi import APIRouter, BackgroundTasks
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from sqlalchemy.sql import asc
from superagi.agent.task_queue import TaskQueue
from superagi.helper.auth import check_auth
from superagi.helper.time_helper import get_time_difference
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.helper.feed_parser import parse_feed
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.lib.logger import logger
from superagi.agent.types.agent_workflow_step_action_types import AgentWorkflowStepAction
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.agent_workflow_step_wait import AgentWorkflowStepWait
import re
# from superagi.types.db import AgentExecutionFeedOut, AgentExecutionFeedIn
router = APIRouter()


class AgentExecutionFeedOut(BaseModel):
    """Response schema for one agent execution feed row, built from the ORM object."""

    id: int
    agent_execution_id: int
    agent_id: int
    feed: str
    role: str
    extra_info: Optional[str]
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow construction straight from SQLAlchemy model instances.
        orm_mode = True


class AgentExecutionFeedIn(BaseModel):
    """Request schema used by the create/update feed endpoints."""

    id: int
    agent_execution_id: int
    agent_id: int
    feed: str
    role: str
    extra_info: str

    class Config:
        # Allow construction straight from SQLAlchemy model instances.
        orm_mode = True
# CRUD Operations
@router.post("/add", response_model=AgentExecutionFeedOut, status_code=201)
def create_agent_execution_feed(agent_execution_feed: AgentExecutionFeedIn,
                                Authorize: AuthJWT = Depends(check_auth)):
    """
    Add a new agent execution feed.

    The feed is attached to the execution's current feed group.

    Args:
        agent_execution_feed (AgentExecutionFeedIn): The data for the agent execution feed.
        Authorize (AuthJWT): Authentication dependency.

    Returns:
        AgentExecutionFeed: The newly created agent execution feed.

    Raises:
        HTTPException (Status Code=404): If the associated agent execution is not found.
    """
    agent_execution = db.session.query(AgentExecution).get(agent_execution_feed.agent_execution_id)
    if not agent_execution:
        raise HTTPException(status_code=404, detail="Agent Execution not found")

    # The request schema carries the message role in `role`; it has no `type`
    # attribute, so reading `.type` raised AttributeError on every request.
    # agent_id is taken from the execution so the row stays consistent with it
    # (the response model also requires agent_id and role to be populated).
    db_agent_execution_feed = AgentExecutionFeed(agent_execution_id=agent_execution_feed.agent_execution_id,
                                                 agent_id=agent_execution.agent_id,
                                                 feed=agent_execution_feed.feed,
                                                 role=agent_execution_feed.role,
                                                 extra_info=agent_execution_feed.extra_info,
                                                 feed_group_id=agent_execution.current_feed_group_id)
    db.session.add(db_agent_execution_feed)
    db.session.commit()
    return db_agent_execution_feed
@router.get("/get/{agent_execution_feed_id}", response_model=AgentExecutionFeedOut)
def get_agent_execution_feed(agent_execution_feed_id: int,
                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Fetch a single agent execution feed row by its primary key.

    Args:
        agent_execution_feed_id (int): The ID of the agent execution feed.
        Authorize (AuthJWT): Authentication dependency.

    Returns:
        AgentExecutionFeed: The matching feed row.

    Raises:
        HTTPException (Status Code=404): If no feed row has this ID.
    """
    feed_row = (
        db.session.query(AgentExecutionFeed)
        .filter(AgentExecutionFeed.id == agent_execution_feed_id)
        .first()
    )
    if feed_row:
        return feed_row
    raise HTTPException(status_code=404, detail="agent_execution_feed not found")
@router.put("/update/{agent_execution_feed_id}", response_model=AgentExecutionFeedOut)
def update_agent_execution_feed(agent_execution_feed_id: int,
                                agent_execution_feed: AgentExecutionFeedIn,
                                Authorize: AuthJWT = Depends(check_auth)):
    """
    Update a particular agent execution feed.

    Args:
        agent_execution_feed_id (int): The ID of the agent execution feed to update.
        agent_execution_feed (AgentExecutionFeedIn): The updated agent execution feed.
        Authorize (AuthJWT): Authentication dependency.

    Returns:
        AgentExecutionFeed: The updated agent execution feed.

    Raises:
        HTTPException (Status Code=404): If the agent execution feed or agent execution is not found.
    """
    db_agent_execution_feed = db.session.query(AgentExecutionFeed).filter(
        AgentExecutionFeed.id == agent_execution_feed_id).first()
    if not db_agent_execution_feed:
        raise HTTPException(status_code=404, detail="Agent Execution Feed not found")

    if agent_execution_feed.agent_execution_id:
        agent_execution = db.session.query(AgentExecution).get(agent_execution_feed.agent_execution_id)
        if not agent_execution:
            raise HTTPException(status_code=404, detail="Agent Execution not found")
        db_agent_execution_feed.agent_execution_id = agent_execution.id
        # Keep the feed's agent in sync with the execution it now belongs to.
        db_agent_execution_feed.agent_id = agent_execution.agent_id

    # The request schema exposes the message role as `role`; it has no `type`
    # field, so reading `.type` raised AttributeError.
    if agent_execution_feed.role is not None:
        db_agent_execution_feed.role = agent_execution_feed.role
    if agent_execution_feed.feed is not None:
        db_agent_execution_feed.feed = agent_execution_feed.feed
    # extra_info is not updated; the assignment below is commented out.
    # if agent_execution_feed.extra_info is not None:
    #     db_agent_execution_feed.extra_info = agent_execution_feed.extra_info

    db.session.commit()
    return db_agent_execution_feed
@router.get("/get/execution/{agent_execution_id}")
def get_agent_execution_feed(agent_execution_id: int,
                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Get agent execution feed with other execution details.

    Also surfaces run errors. When a feed entry has an error message and a higher ID
    than ``last_shown_error_id`` (or none has been shown yet), the execution is marked
    ``ERROR_PAUSED`` and that entry is recorded as the last shown error.

    Args:
        agent_execution_id (int): The ID of the agent execution.
        Authorize (AuthJWT): Authentication dependency.

    Returns:
        dict: A dict with these keys:
            - ``status``: the execution status.
            - ``feeds``: parsed feeds, excluding empty entries and
              "current time and date" entries.
            - ``permissions``: the execution's permission requests.
            - ``waiting_period``: the delay of the current wait step, or None.
            - ``errors``: the error message to display, or "".

    Raises:
        HTTPException (Status Code=400): If the agent run is not found.
    """
    agent_execution = db.session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
    if agent_execution is None:
        raise HTTPException(status_code=400, detail="Agent Run not found!")

    feeds = db.session.query(AgentExecutionFeed).filter_by(agent_execution_id=agent_execution_id).order_by(
        asc(AgentExecutionFeed.created_at)).all()

    # Feed entries that announce the current time/date are left out of the response.
    time_announcement = re.compile(
        r"The current time and date is\s(\w{3}\s\w{3}\s\s?\d{1,2}\s\d{2}:\d{2}:\d{2}\s\d{4})")

    final_feeds = []
    error = ""
    error_state_changed = False
    for feed in feeds:
        if feed.error_message:
            if agent_execution.last_shown_error_id is None or feed.id > agent_execution.last_shown_error_id:
                # A new error occurred: record it as shown and pause the run.
                error = feed.error_message
                agent_execution.last_shown_error_id = feed.id
                agent_execution.status = "ERROR_PAUSED"
                error_state_changed = True
            if feed.id == agent_execution.last_shown_error_id and agent_execution.status == "ERROR_PAUSED":
                error = feed.error_message
        # The truthiness check skips "" and also guards against a None feed,
        # which would make re.search raise TypeError.
        if feed.feed and time_announcement.search(feed.feed) is None:
            final_feeds.append(parse_feed(feed))

    if error_state_changed:
        # Commit once after the loop. Committing inside it (with the default
        # expire_on_commit) expired every loaded feed row and forced reloads.
        db.session.commit()

    # get all permissions
    execution_permissions = db.session.query(AgentExecutionPermission). \
        filter_by(agent_execution_id=agent_execution_id). \
        order_by(asc(AgentExecutionPermission.created_at)).all()
    # One reference time for every permission in this response.
    now = str(datetime.now())
    permissions = [
        {
            "id": permission.id,
            "created_at": permission.created_at,
            "response": permission.user_feedback,
            "status": permission.status,
            "tool_name": permission.tool_name,
            "question": permission.question,
            "user_feedback": permission.user_feedback,
            "time_difference": get_time_difference(permission.created_at, now)
        } for permission in execution_permissions
    ]

    waiting_period = None
    if agent_execution.status == AgentWorkflowStepAction.WAIT_STEP.value:
        workflow_step = AgentWorkflowStep.find_by_id(db.session, agent_execution.current_agent_step_id)
        # Guard against a dangling step / wait reference instead of raising AttributeError.
        if workflow_step is not None:
            step_wait = AgentWorkflowStepWait.find_by_id(db.session, workflow_step.action_reference_id)
            if step_wait is not None:
                waiting_period = step_wait.delay

    return {
        "status": agent_execution.status,
        "feeds": final_feeds,
        "permissions": permissions,
        "waiting_period": waiting_period,
        "errors": error
    }
@router.get("/get/tasks/{agent_execution_id}")
def get_execution_tasks(agent_execution_id: int,
                        Authorize: AuthJWT = Depends(check_auth)):
    """
    List the pending and completed tasks of an agent execution.

    Args:
        agent_execution_id (int): The ID of the agent execution.
        Authorize (AuthJWT): Authentication dependency.

    Returns:
        dict: ``tasks`` holds pending task names. ``completed_tasks`` holds completed
        task names in the reverse of the order get_completed_tasks() returns them.
        Each entry has the form ``{"name": ...}``.
    """
    queue = TaskQueue(str(agent_execution_id))
    pending = [{"name": task_name} for task_name in queue.get_tasks()]
    finished = [{"name": entry['task']} for entry in reversed(queue.get_completed_tasks())]
    return {
        "tasks": pending,
        "completed_tasks": finished
    }
================================================
FILE: superagi/controllers/agent_execution_permission.py
================================================
from datetime import datetime
from typing import Annotated
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Body
from fastapi_jwt_auth import AuthJWT
from pydantic import BaseModel
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.worker import execute_agent
from fastapi import APIRouter
from superagi.helper.auth import check_auth
# from superagi.types.db import AgentExecutionPermissionOut, AgentExecutionPermissionIn
router = APIRouter()


class AgentExecutionPermissionOut(BaseModel):
    """Response schema for a stored agent execution permission request."""

    id: int
    agent_execution_id: int
    agent_id: int
    status: str
    tool_name: str
    user_feedback: str
    assistant_reply: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow construction straight from SQLAlchemy model instances.
        orm_mode = True


class AgentExecutionPermissionIn(BaseModel):
    """Request schema for creating or replacing an agent execution permission."""

    agent_execution_id: int
    agent_id: int
    status: str
    tool_name: str
    user_feedback: str
    assistant_reply: str

    class Config:
        # Allow construction straight from SQLAlchemy model instances.
        orm_mode = True
@router.get("/get/{agent_execution_permission_id}")
def get_agent_execution_permission(agent_execution_permission_id: int,
                                   Authorize: AuthJWT = Depends(check_auth)):
    """
    Look up an agent execution permission by primary key.

    Args:
        agent_execution_permission_id (int): The ID of the agent execution permission.
        Authorize (AuthJWT, optional): Authentication object. Defaults to Depends(check_auth).

    Returns:
        AgentExecutionPermission: The requested agent execution permission.

    Raises:
        HTTPException: 404 if no permission has this ID.
    """
    permission = db.session.query(AgentExecutionPermission).get(agent_execution_permission_id)
    if permission:
        return permission
    raise HTTPException(status_code=404, detail="Agent execution permission not found")
@router.post("/add", response_model=AgentExecutionPermissionOut)
def create_agent_execution_permission(
        agent_execution_permission: AgentExecutionPermissionIn,
        Authorize: AuthJWT = Depends(check_auth)):
    """
    Persist a new agent execution permission built from the request body.

    Args:
        agent_execution_permission (AgentExecutionPermissionIn): Field values for the new row.
        Authorize (AuthJWT, optional): Authorization token, by default depends on the check_auth function.

    Returns:
        AgentExecutionPermission: The newly created row.
    """
    fields = agent_execution_permission.dict()
    created = AgentExecutionPermission(**fields)
    db.session.add(created)
    db.session.commit()
    return created
@router.patch("/update/{agent_execution_permission_id}",
              response_model=AgentExecutionPermissionIn)
def update_agent_execution_permission(agent_execution_permission_id: int,
                                      agent_execution_permission: AgentExecutionPermissionIn,
                                      Authorize: AuthJWT = Depends(check_auth)):
    """
    Overwrite an AgentExecutionPermission with every field from the request body.

    Args:
        agent_execution_permission_id (int): The ID of the AgentExecutionPermission to update.
        agent_execution_permission (AgentExecutionPermissionIn): The new field values.
        Authorize (AuthJWT, optional): Dependency to authenticate the user.

    Returns:
        AgentExecutionPermission: The updated row.

    Raises:
        HTTPException: 404 if no permission has this ID.
    """
    permission = db.session.query(AgentExecutionPermission).get(agent_execution_permission_id)
    if not permission:
        raise HTTPException(status_code=404, detail="Agent execution permission not found")

    new_values = agent_execution_permission.dict()
    for field_name in new_values:
        setattr(permission, field_name, new_values[field_name])

    db.session.commit()
    return permission
@router.put("/update/status/{agent_execution_permission_id}")
def update_agent_execution_permission_status(agent_execution_permission_id: int,
                                             status: Annotated[bool, Body(embed=True)],
                                             user_feedback: Annotated[str, Body(embed=True)] = "",
                                             Authorize: AuthJWT = Depends(check_auth)):
    """
    Approve or reject an agent execution permission and re-enqueue the agent.

    Sets the status to "APPROVED" (status=True) or "REJECTED" (status=False). Stores
    the stripped user feedback, or None when it is blank. Commits, then enqueues the
    execution via execute_agent.delay.

    :params:
    - agent_execution_permission_id (int): The ID of the agent execution permission
    - status (bool): True for "APPROVED", False for "REJECTED"
    - user_feedback (str): Optional user feedback on the status update
    - Authorize (AuthJWT): Dependency function to check user authorization
    :return:
    - A dictionary containing a "success" key with the value True to indicate a successful update.
    :raises:
    - HTTPException (400): If the permission does not exist or status is missing.
    """
    agent_execution_permission = db.session.query(AgentExecutionPermission).get(agent_execution_permission_id)
    # (A debug print of the ORM object used to run here on every request.)
    if agent_execution_permission is None:
        raise HTTPException(status_code=400, detail="Invalid Request")
    if status is None:
        raise HTTPException(status_code=400, detail="Invalid Request status is required")

    agent_execution_permission.status = "APPROVED" if status else "REJECTED"
    # Blank feedback is stored as NULL; `or ""` also tolerates a None value.
    feedback = (user_feedback or "").strip()
    agent_execution_permission.user_feedback = feedback or None
    db.session.commit()

    execute_agent.delay(agent_execution_permission.agent_execution_id, datetime.now())
    return {"success": True}
================================================
FILE: superagi/controllers/agent_template.py
================================================
from datetime import datetime
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from main import get_config
from superagi.controllers.types.agent_execution_config import AgentRunIn
from superagi.controllers.types.agent_publish_config import AgentPublish
from superagi.helper.auth import get_user_organisation
from superagi.helper.auth import get_current_user
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_template import AgentTemplate
from superagi.models.agent_template_config import AgentTemplateConfig
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.tool import Tool
import json
# from superagi.types.db import AgentTemplateIn, AgentTemplateOut
router = APIRouter()


class AgentTemplateOut(BaseModel):
    """Response schema for an agent template row."""

    id: int
    organisation_id: int
    agent_workflow_id: int
    name: str
    description: str
    marketplace_template_id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow construction straight from SQLAlchemy model instances.
        orm_mode = True


class AgentTemplateIn(BaseModel):
    """Request schema for agent template fields."""

    organisation_id: int
    agent_workflow_id: int
    name: str
    description: str
    marketplace_template_id: int

    class Config:
        # Allow construction straight from SQLAlchemy model instances.
        orm_mode = True
@router.get("/get/{agent_template_id}")
def get_agent_template(template_source, agent_template_id: int, organisation=Depends(get_user_organisation)):
    """
    Get the details of a specific agent template.

    Args:
        template_source (str): The source of the agent template ("local" or "marketplace").
            Any value other than "local" is fetched from the marketplace.
        agent_template_id (int): The ID of the agent template.
        organisation (Depends): Dependency to get the user organisation.

    Returns:
        dict: The details of the agent template. Local templates include "configs"
        (key -> {"value": ...}) and "agent_workflow_name".

    Raises:
        HTTPException (status_code=404): If the local agent template is not found.
    """
    if template_source != "local":
        return AgentTemplate.fetch_marketplace_detail(agent_template_id)

    db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
                                                               AgentTemplate.id == agent_template_id).first()
    if not db_agent_template:
        raise HTTPException(status_code=404, detail="Agent template not found")

    template = db_agent_template.to_dict()
    agent_template_configs = db.session.query(AgentTemplateConfig).filter(
        AgentTemplateConfig.agent_template_id == agent_template_id).all()
    # find_by_id takes the session first (see fetch_agent_config_from_template);
    # calling it with only the workflow id raised a TypeError.
    agent_workflow = AgentWorkflow.find_by_id(db.session, db_agent_template.agent_workflow_id)

    template["configs"] = {
        config.key: {"value": AgentTemplate.eval_agent_config(config.key, config.value)}
        for config in agent_template_configs
    }
    template["agent_workflow_name"] = agent_workflow.name if agent_workflow is not None else None
    return template
@router.put("/update_agent_template/{agent_template_id}", status_code=200)
def edit_agent_template(agent_template_id: int,
                        updated_agent_configs: dict,
                        organisation=Depends(get_user_organisation)):
    """
    Update the details of an agent template.

    Args:
        agent_template_id (int): The ID of the agent template to update.
        updated_agent_configs (dict): Must contain "name" and "description". May contain
            "agent_configs", a dict of config key -> value. An optional "agent_workflow"
            entry in it selects the workflow by name. list/dict values are stored as JSON.
        organisation (Depends): Dependency to get the user organisation.

    Returns:
        None (HTTP 200) when the template is successfully edited.

    Raises:
        HTTPException (status_code=404): If the agent template or the named workflow is not found.
    """
    db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
                                                               AgentTemplate.id == agent_template_id).first()
    if db_agent_template is None:
        raise HTTPException(status_code=404, detail="Agent Template not found")

    agent_config_values = updated_agent_configs.get('agent_configs') or {}

    # Resolve the workflow before modifying anything. An unknown name used to crash
    # with AttributeError, and a missing key with KeyError.
    agent_workflow = None
    workflow_name = agent_config_values.get("agent_workflow")
    if workflow_name is not None:
        agent_workflow = AgentWorkflow.find_by_name(db.session, workflow_name)
        if agent_workflow is None:
            raise HTTPException(status_code=404, detail="Agent Workflow not found")

    db_agent_template.name = updated_agent_configs["name"]
    db_agent_template.description = updated_agent_configs["description"]
    if agent_workflow is not None:
        db_agent_template.agent_workflow_id = agent_workflow.id
    db.session.commit()

    for key, value in agent_config_values.items():
        if isinstance(value, (list, dict)):
            value = json.dumps(value)
        config = db.session.query(AgentTemplateConfig).filter(
            AgentTemplateConfig.agent_template_id == agent_template_id,
            AgentTemplateConfig.key == key
        ).first()
        if config is not None:
            config.value = value
        else:
            db.session.add(AgentTemplateConfig(agent_template_id=agent_template_id, key=key, value=value))
    db.session.commit()
# @router.put("/update_agent_template/{agent_template_id}", status_code=200)
# def edit_agent_template(agent_template_id: int,
# updated_agent_configs: dict,
# organisation=Depends(get_user_organisation)):
# """
# Update the details of an agent template.
# Args:
# agent_template_id (int): The ID of the agent template to update.
# edited_agent_configs (dict): The updated agent configurations.
# organisation (Depends): Dependency to get the user organisation.
# Returns:
# HTTPException (status_code=200): If the agent gets successfully edited.
# Raises:
# HTTPException (status_code=404): If the agent template is not found.
# """
# db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
# AgentTemplate.id == agent_template_id).first()
# if db_agent_template is None:
# raise HTTPException(status_code=404, detail="Agent Template not found")
# db_agent_template.name = updated_agent_configs["name"]
# db_agent_template.description = updated_agent_configs["description"]
# db.session.commit()
# agent_config_values = updated_agent_configs.get('agent_configs', {})
# for key, value in agent_config_values.items():
# if isinstance(value, (list, dict)):
# value = json.dumps(value)
# config = db.session.query(AgentTemplateConfig).filter(
# AgentTemplateConfig.agent_template_id == agent_template_id,
# AgentTemplateConfig.key == key
# ).first()
# if config is not None:
# config.value = value
# else:
# new_config = AgentTemplateConfig(
# agent_template_id=agent_template_id,
# key=key,
# value= value
# )
# db.session.add(new_config)
# db.session.commit()
# db.session.flush()
@router.post("/save_agent_as_template/agent_id/{agent_id}/agent_execution_id/{agent_execution_id}")
def save_agent_as_template(agent_execution_id: str,
                           agent_id: str,
                           organisation=Depends(get_user_organisation)):
    """
    Save an agent as a template.

    With agent_execution_id == "-1" the agent-level configuration is copied. Otherwise
    the configuration of that specific execution is copied. Only keys in
    AgentTemplate.main_keys() are kept, and tool ids are converted to tool names.

    Args:
        agent_id (str): The ID of the agent to save as a template.
        agent_execution_id (str): The ID of the agent execution to save as a template, or "-1".
        organisation (Depends): Dependency to get the user organisation.

    Returns:
        dict: The saved agent template.

    Raises:
        HTTPException (status_code=404): If an ID is "undefined", or the agent or its
            configurations are not found.
    """
    import ast  # local: only needed to parse the stored tools list safely

    if agent_execution_id == 'undefined':
        raise HTTPException(status_code=404, detail="Agent Execution Id undefined")
    if agent_id == 'undefined':
        raise HTTPException(status_code=404, detail="Agent Id undefined")

    agent = db.session.query(Agent).filter(Agent.id == agent_id).first()
    if agent is None:
        raise HTTPException(status_code=404, detail="Agent not found")

    if agent_execution_id == "-1":
        configs = db.session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all()
        if not configs:
            raise HTTPException(status_code=404, detail="Agent configurations not found")
    else:
        configs = db.session.query(AgentExecutionConfiguration).filter(
            AgentExecutionConfiguration.agent_execution_id == agent_execution_id).all()
        if not configs:
            raise HTTPException(status_code=404, detail="Agent execution configurations not found")

    agent_template = AgentTemplate(name=agent.name, description=agent.description,
                                   agent_workflow_id=agent.agent_workflow_id,
                                   organisation_id=organisation.id)
    db.session.add(agent_template)
    db.session.commit()

    main_keys = AgentTemplate.main_keys()
    for config in configs:
        if config.key not in main_keys:
            continue
        config_value = config.value
        if config.key == "tools":
            # The stored value is presumably the repr of a list of tool ids. Parse it
            # with literal_eval rather than eval() so DB content cannot execute code.
            config_value = str(Tool.convert_tool_ids_to_names(db, ast.literal_eval(config.value)))
        db.session.add(AgentTemplateConfig(agent_template_id=agent_template.id, key=config.key,
                                           value=config_value))

    db.session.commit()
    return agent_template.to_dict()
@router.get("/list")
def list_agent_templates(template_source="local", search_str="", page=0, organisation=Depends(get_user_organisation)):
    """
    List agent templates.

    Args:
        template_source (str, optional): The source of the templates ("local" or "marketplace"). Defaults to "local".
        search_str (str, optional): The search string to filter marketplace templates. Defaults to "".
        page (int, optional): The page number for paginated marketplace results. Defaults to 0.
        organisation (Depends): Dependency to get the user organisation.

    Returns:
        list: Local template rows with ``updated_at`` formatted as "DD-MON-YYYY", or
        marketplace template dicts annotated with ``is_installed`` and ``organisation_id``.
    """
    output_json = []
    if template_source == "local":
        templates = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id).all()
        for template in templates:
            template.updated_at = template.updated_at.strftime('%d-%b-%Y').upper()
            output_json.append(template)
        return output_json

    # Marketplace ids already cloned into this organisation.
    local_templates = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
                                                             AgentTemplate.marketplace_template_id.isnot(None)).all()
    installed_ids = {local_template.marketplace_template_id for local_template in local_templates}

    # (Debug print() calls of the hash and marketplace payload were removed.)
    for template in AgentTemplate.fetch_marketplace_list(search_str, page):
        template["is_installed"] = template["id"] in installed_ids
        template["organisation_id"] = organisation.id
        output_json.append(template)
    return output_json
@router.get("/marketplace/list")
def list_marketplace_templates(page: int = 0):
    """
    Get all marketplace agent templates, paginated.

    Args:
        page (int, optional): Zero-based page number. Defaults to 0. Annotated as int so
            FastAPI converts the query string. Unannotated, "2" stayed a str and
            ``page * page_size`` produced a repeated string instead of an offset.

    Returns:
        list: Up to 30 marketplace agent templates, with ``updated_at`` formatted as "DD-MON-YYYY".
    """
    organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    page_size = 30
    templates = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation_id).offset(
        page * page_size).limit(page_size).all()
    output_json = []
    for template in templates:
        template.updated_at = template.updated_at.strftime('%d-%b-%Y').upper()
        output_json.append(template)
    return output_json
@router.get("/marketplace/template_details/{agent_template_id}")
def marketplace_template_detail(agent_template_id):
    """
    Get marketplace template details.

    Args:
        agent_template_id (int): The ID of the marketplace agent template.

    Returns:
        dict: id, name, description, agent_workflow_id, agent_workflow_name and
        configs (key -> {"value": ...}).

    Raises:
        HTTPException (status_code=404): If no template with this ID belongs to the
            marketplace organisation.
    """
    organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation_id,
                                                      AgentTemplate.id == agent_template_id).first()
    if template is None:
        # Previously fell through to `template.id` and failed with AttributeError (HTTP 500).
        raise HTTPException(status_code=404, detail="Agent Template not found")

    template_configs = db.session.query(AgentTemplateConfig).filter(
        AgentTemplateConfig.agent_template_id == template.id).all()
    workflow = db.session.query(AgentWorkflow).filter(AgentWorkflow.id == template.agent_workflow_id).first()

    tool_configs = {
        template_config.key: {"value": AgentTemplate.eval_agent_config(template_config.key, template_config.value)}
        for template_config in template_configs
    }
    return {
        "id": template.id,
        "name": template.name,
        "description": template.description,
        "agent_workflow_id": template.agent_workflow_id,
        "agent_workflow_name": workflow.name if workflow is not None else None,
        "configs": tool_configs
    }
@router.post("/download", status_code=201)
def download_template(agent_template_id: int,
                      organisation=Depends(get_user_organisation)):
    """
    Clone a marketplace agent template into the caller's organisation.

    Args:
        agent_template_id (int): The ID of the marketplace agent template.
        organisation: User's organisation (the clone's owner).

    Returns:
        dict: The cloned template as returned by its ``to_dict()``.
    """
    cloned_template = AgentTemplate.clone_agent_template_from_marketplace(
        db, organisation.id, agent_template_id
    )
    return cloned_template.to_dict()
@router.get("/agent_config", status_code=201)
def fetch_agent_config_from_template(agent_template_id: int,
                                     organisation=Depends(get_user_organisation)):
    """
    Build an agent configuration dict from a template owned by the caller's organisation.

    Only keys in AgentTemplate.main_keys() are taken from the stored configs.
    "instruction" and "constraints" default to [] and every other missing main key
    defaults to "".

    Args:
        agent_template_id (int): The ID of the agent template.
        organisation: User's organisation.

    Returns:
        dict: The configuration, plus "agent_template_id" and "agent_workflow" (workflow name).

    Raises:
        HTTPException: 404 if the template is not found.
    """
    agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.id == agent_template_id,
                                                            AgentTemplate.organisation_id == organisation.id).first()
    if not agent_template:
        raise HTTPException(status_code=404, detail="Template not found")

    stored_configs = db.session.query(AgentTemplateConfig).filter(
        AgentTemplateConfig.agent_template_id == agent_template_id).all()
    main_keys = AgentTemplate.main_keys()

    result = {
        config.key: AgentTemplate.eval_agent_config(config.key, config.value)
        for config in stored_configs
        if config.key in main_keys
    }
    result.setdefault("instruction", [])
    result.setdefault("constraints", [])
    for key in main_keys:
        result.setdefault(key, "")

    result["agent_template_id"] = agent_template.id
    workflow = AgentWorkflow.find_by_id(db.session, agent_template.agent_workflow_id)
    result["agent_workflow"] = workflow.name
    return result
@router.post("/publish_template/agent_execution_id/{agent_execution_id}", status_code=201)
def publish_template(agent_execution_id: str, organisation=Depends(get_user_organisation), user=Depends(get_current_user)):
    """
    Publish an agent execution as a template.

    Copies the execution's configuration (main keys only, tool ids converted to names)
    into a new template. It adds "status" = "UNDER REVIEW" and the contributor's name
    and email.

    Args:
        agent_execution_id (str): The ID of the agent execution to save as a template.
        organisation (Depends): Dependency to get the user organisation.
        user (Depends): Dependency to get the user.

    Returns:
        dict: The saved agent template.

    Raises:
        HTTPException (status_code=404): If the ID is "undefined", or the execution, agent
            or execution configurations are not found.
    """
    import ast  # local: only needed to parse the stored tools list safely

    if agent_execution_id == 'undefined':
        raise HTTPException(status_code=404, detail="Agent Execution Id undefined")
    agent_executions = AgentExecution.get_agent_execution_from_id(db.session, agent_execution_id)
    if agent_executions is None:
        raise HTTPException(status_code=404, detail="Agent Execution not found")

    agent = db.session.query(Agent).filter(Agent.id == agent_executions.agent_id).first()
    if agent is None:
        raise HTTPException(status_code=404, detail="Agent not found")

    agent_execution_configurations = db.session.query(AgentExecutionConfiguration).filter(
        AgentExecutionConfiguration.agent_execution_id == agent_execution_id).all()
    if not agent_execution_configurations:
        raise HTTPException(status_code=404, detail="Agent execution configurations not found")

    agent_template = AgentTemplate(name=agent.name, description=agent.description,
                                   agent_workflow_id=agent.agent_workflow_id,
                                   organisation_id=organisation.id)
    db.session.add(agent_template)
    db.session.commit()

    main_keys = AgentTemplate.main_keys()
    for agent_execution_configuration in agent_execution_configurations:
        if agent_execution_configuration.key not in main_keys:
            continue
        config_value = agent_execution_configuration.value
        if agent_execution_configuration.key == "tools":
            # Presumably the repr of a list of tool ids. literal_eval parses it
            # without executing arbitrary code, unlike eval().
            config_value = str(Tool.convert_tool_ids_to_names(
                db, ast.literal_eval(agent_execution_configuration.value)))
        db.session.add(AgentTemplateConfig(agent_template_id=agent_template.id,
                                           key=agent_execution_configuration.key,
                                           value=config_value))

    db.session.add_all([
        AgentTemplateConfig(agent_template_id=agent_template.id, key="status", value="UNDER REVIEW"),
        AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Name", value=user.name),
        AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Email", value=user.email)])
    db.session.commit()
    return agent_template.to_dict()
@router.post("/publish_template", status_code=201)
def handle_publish_template(updated_details: AgentPublish, organisation=Depends(get_user_organisation), user=Depends(get_current_user)):
    """
    Publish a new template built from the edit-template page.

    The new template reuses the workflow of the existing template identified by
    ``updated_details.agent_template_id``. It stores each edited field as a config
    (tool ids converted to names, values stringified) and adds "status" = "UNDER REVIEW"
    and the contributor's name and email.

    Args:
        updated_details (AgentPublish): The edited template fields.
        organisation (Depends): Dependency to get the user organisation.
        user (Depends): Dependency to get the user.

    Returns:
        dict: The saved agent template.

    Raises:
        HTTPException (status_code=404): If the agent template or workflow are not found.
    """
    source_template = db.session.query(AgentTemplate).filter(
        AgentTemplate.id == updated_details.agent_template_id,
        AgentTemplate.organisation_id == organisation.id).first()
    if source_template is None:
        raise HTTPException(status_code=404, detail="Agent Template not found")

    workflow_id = source_template.agent_workflow_id
    if workflow_id is None:
        raise HTTPException(status_code=404, detail="Agent Workflow not found")

    new_template = AgentTemplate(name=updated_details.name, description=updated_details.description,
                                 agent_workflow_id=workflow_id,
                                 organisation_id=organisation.id)
    db.session.add(new_template)
    db.session.commit()

    config_keys = ("goal", "instruction", "constraints", "toolkits", "exit", "tools",
                   "iteration_interval", "model", "permission_type", "LTM_DB",
                   "max_iterations", "user_timezone", "knowledge")
    edited_values = {key: getattr(updated_details, key) for key in config_keys}
    for key, value in edited_values.items():
        stored = Tool.convert_tool_ids_to_names(db, value) if key == "tools" else value
        db.session.add(AgentTemplateConfig(agent_template_id=new_template.id, key=key, value=str(stored)))

    review_configs = [
        AgentTemplateConfig(agent_template_id=new_template.id, key="status", value="UNDER REVIEW"),
        AgentTemplateConfig(agent_template_id=new_template.id, key="Contributor Name", value=user.name),
        AgentTemplateConfig(agent_template_id=new_template.id, key="Contributor Email", value=user.email),
    ]
    db.session.add_all(review_configs)
    db.session.commit()
    db.session.flush()
    return new_template.to_dict()
================================================
FILE: superagi/controllers/agent_workflow.py
================================================
from fastapi import APIRouter
from fastapi import Depends
from fastapi_sqlalchemy import db
from superagi.helper.auth import get_user_organisation
from superagi.models.workflows.agent_workflow import AgentWorkflow
router = APIRouter()


@router.get("/list", status_code=201)
def list_workflows(organisation=Depends(get_user_organisation)):
    """
    List every agent workflow.

    The ``organisation`` dependency is resolved but not used to filter the results.
    NOTE(review): presumably it only acts as an auth gate — confirm.

    Args:
        organisation: User's organisation.

    Returns:
        list: Each workflow's ``to_dict()`` representation.
    """
    return [workflow.to_dict() for workflow in db.session.query(AgentWorkflow).all()]
================================================
FILE: superagi/controllers/analytics.py
================================================
from fastapi import APIRouter, Depends, HTTPException
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.apm.analytics_helper import AnalyticsHelper
from superagi.apm.event_handler import EventHandler
from superagi.apm.tools_handler import ToolsHandler
from superagi.apm.knowledge_handler import KnowledgeHandler
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
import logging
router = APIRouter()


@router.get("/metrics", status_code=200)
def get_metrics(organisation=Depends(get_user_organisation)):
    """
    Get the total tokens, total calls, and the number of runs completed.

    Args:
        organisation: The caller's organisation; its id scopes the metrics.

    Returns:
        metrics: Whatever ``AnalyticsHelper.calculate_run_completed_metrics`` returns
            (total tokens, total calls and number of runs completed).

    Raises:
        HTTPException: 500 if computing the metrics fails; the error is logged first.
    """
    try:
        analytics = AnalyticsHelper(session=db.session, organisation_id=organisation.id)
        metrics = analytics.calculate_run_completed_metrics()
    except Exception as exc:
        logging.error(f"Error while calculating metrics: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return metrics
@router.get("/agents/all", status_code=200)
def get_agents(organisation=Depends(get_user_organisation)):
    """
    Fetch analytics data for the agents of the caller's organisation.

    Returns:
        Whatever ``AnalyticsHelper.fetch_agent_data`` returns.

    Raises:
        HTTPException: 500 if fetching fails; the error is logged first.
    """
    try:
        agent_data = AnalyticsHelper(session=db.session, organisation_id=organisation.id).fetch_agent_data()
    except Exception as exc:
        logging.error(f"Error while fetching agent data: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    else:
        return agent_data
@router.get("/agents/{agent_id}", status_code=200)
def get_agent_runs(agent_id: int, organisation=Depends(get_user_organisation)):
    """
    Fetch run analytics for one agent of the caller's organisation.

    Args:
        agent_id: Id of the agent whose runs are requested.
        organisation: The caller's organisation; its id scopes the lookup.

    Raises:
        HTTPException: 500 if fetching fails (including any error raised by the helper);
            the error is logged first.
    """
    try:
        helper = AnalyticsHelper(session=db.session, organisation_id=organisation.id)
        runs = helper.fetch_agent_runs(agent_id)
    except Exception as exc:
        logging.error(f"Error while fetching agent runs: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return runs
@router.get("/runs/active", status_code=200)
def get_active_runs(organisation=Depends(get_user_organisation)):
    """
    List the currently active runs of the caller's organisation.

    Raises:
        HTTPException: 500 if ``AnalyticsHelper.get_active_runs`` fails; the error is logged first.
    """
    try:
        active_runs = AnalyticsHelper(session=db.session, organisation_id=organisation.id).get_active_runs()
    except Exception as exc:
        logging.error(f"Error while getting active runs: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return active_runs
@router.get("/tools/used", status_code=200)
def get_tools_used(organisation=Depends(get_user_organisation)):
    """
    Aggregate tool usage for the caller's organisation.

    Raises:
        HTTPException: 500 if ``ToolsHandler.calculate_tool_usage`` fails; the error is logged first.
    """
    try:
        tools_handler = ToolsHandler(session=db.session, organisation_id=organisation.id)
        usage = tools_handler.calculate_tool_usage()
    except Exception as exc:
        logging.error(f"Error while calculating tool usage: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return usage
@router.get("/tools/{tool_name}/usage", status_code=200)
def get_tool_usage(tool_name: str, organisation=Depends(get_user_organisation)):
    """
    Usage statistics for a single tool within the caller's organisation.

    Args:
        tool_name: Name of the tool to report on.
        organisation: The caller's organisation; its id scopes the lookup.

    Returns:
        Whatever ``ToolsHandler.get_tool_usage_by_name`` returns.

    Raises:
        HTTPException: with the underlying error's ``status_code`` (and ``detail`` when present)
            if the error carries a status code, otherwise 500.
    """
    try:
        return ToolsHandler(session=db.session, organisation_id=organisation.id).get_tool_usage_by_name(tool_name)
    except Exception as e:
        # Log like the sibling analytics endpoints so failures are visible server-side.
        logging.error(f"Error while getting tool usage: {str(e)}")
        if hasattr(e, 'status_code'):
            # Not every exception exposing status_code also has .detail; fall back to its message.
            raise HTTPException(status_code=e.status_code, detail=getattr(e, 'detail', str(e))) from e
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
@router.get("/knowledge/{knowledge_name}/usage", status_code=200)
def get_knowledge_usage(knowledge_name: str, organisation=Depends(get_user_organisation)):
    """
    Usage statistics for a single knowledge entry within the caller's organisation.

    Args:
        knowledge_name: Name of the knowledge entry to report on.
        organisation: The caller's organisation; its id scopes the lookup.

    Returns:
        Whatever ``KnowledgeHandler.get_knowledge_usage_by_name`` returns.

    Raises:
        HTTPException: with the underlying error's ``status_code`` (and ``detail`` when present)
            if the error carries a status code, otherwise 500.
    """
    try:
        return KnowledgeHandler(session=db.session,
                                organisation_id=organisation.id).get_knowledge_usage_by_name(knowledge_name)
    except Exception as e:
        # Log like the sibling analytics endpoints so failures are visible server-side.
        logging.error(f"Error while getting knowledge usage: {str(e)}")
        if hasattr(e, 'status_code'):
            # Not every exception exposing status_code also has .detail; fall back to its message.
            raise HTTPException(status_code=e.status_code, detail=getattr(e, 'detail', str(e))) from e
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
@router.get("/tools/{tool_name}/logs", status_code=200)
def get_tool_logs(tool_name: str, organisation=Depends(get_user_organisation)):
    """
    Event log entries for a single tool within the caller's organisation.

    Raises:
        HTTPException: re-raised with the underlying error's ``status_code``/``detail`` when the
            error carries a status code, otherwise 500. The error is always logged first.
    """
    try:
        tools_handler = ToolsHandler(session=db.session, organisation_id=organisation.id)
        return tools_handler.get_tool_events_by_name(tool_name)
    except Exception as err:
        logging.error(f"Error while getting tool event details: {str(err)}")
        if not hasattr(err, 'status_code'):
            raise HTTPException(status_code=500, detail="Internal Server Error")
        raise HTTPException(status_code=err.status_code, detail=err.detail)
@router.get("/knowledge/{knowledge_name}/logs", status_code=200)
def get_knowledge_logs(knowledge_name: str, organisation=Depends(get_user_organisation)):
    """
    Event log entries for a single knowledge entry within the caller's organisation.

    Raises:
        HTTPException: re-raised with the underlying error's ``status_code``/``detail`` when the
            error carries a status code, otherwise 500. The error is always logged first.
    """
    try:
        knowledge_handler = KnowledgeHandler(session=db.session, organisation_id=organisation.id)
        events = knowledge_handler.get_knowledge_events_by_name(knowledge_name)
    except Exception as err:
        logging.error(f"Error while getting knowledge event details: {str(err)}")
        if not hasattr(err, 'status_code'):
            raise HTTPException(status_code=500, detail="Internal Server Error")
        raise HTTPException(status_code=err.status_code, detail=err.detail)
    return events
================================================
FILE: superagi/controllers/api/agent.py
================================================
from fastapi import APIRouter
from fastapi import HTTPException, Depends ,Security
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from superagi.worker import execute_agent
from superagi.helper.auth import validate_api_key,get_organisation_from_api_key
from superagi.models.agent import Agent
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_schedule import AgentSchedule
from superagi.models.project import Project
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.agent_execution import AgentExecution
from superagi.models.organisation import Organisation
from superagi.models.knowledges import Knowledges
from superagi.models.resource import Resource
from superagi.controllers.types.agent_with_config import AgentConfigExtInput,AgentConfigUpdateExtInput
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.helper.s3_helper import S3Helper
from datetime import datetime
from typing import Optional,List
from superagi.models.toolkit import Toolkit
from superagi.apm.event_handler import EventHandler
from superagi.config.config import get_config
router = APIRouter()


class AgentExecutionIn(BaseModel):
    """Request body for starting a run: optional run name and optional goal/instruction overrides."""
    name: Optional[str]
    goal: Optional[List[str]]
    instruction: Optional[List[str]]

    class Config:
        # Allow the model to be populated from ORM objects (pydantic v1 orm_mode).
        orm_mode = True


class RunFilterConfigIn(BaseModel):
    """Filter for listing runs: optional run ids and an optional status (upper-cased by the endpoint)."""
    run_ids:Optional[List[int]]
    run_status_filter:Optional[str]

    class Config:
        orm_mode = True


class ExecutionStateChangeConfigIn(BaseModel):
    """Request body for pausing/resuming runs; ``run_ids`` restricts the change to specific runs."""
    run_ids:Optional[List[int]]

    class Config:
        orm_mode = True


class RunIDConfig(BaseModel):
    """Request body listing the run ids whose output resources are requested."""
    run_ids:List[int]

    class Config:
        orm_mode = True
@router.post("", status_code=200)
def create_agent_with_config(agent_with_config: AgentConfigExtInput,
                             api_key: str = Security(validate_api_key),
                             organisation: Organisation = Depends(get_organisation_from_api_key)):
    """
    Create an agent from an external (API-key authenticated) request.

    Scheduled agents only get their schedule saved. Unscheduled agents also get an execution in
    CREATED state carrying the submitted configuration; ``create_run`` (POST /{agent_id}/run)
    later picks that execution up and starts it.

    Args:
        agent_with_config: Agent settings; ``tools`` is resolved against the organisation's toolkits.
        api_key: API key, validated by ``validate_api_key``.
        organisation: Organisation resolved from the API key.

    Returns:
        dict: ``{"agent_id": <id of the new agent>}``.

    Raises:
        HTTPException: 404 if the organisation has no project or a requested tool cannot be
            resolved; 500 if saving the schedule fails.
    """
    project = Project.find_by_org_id(db.session, organisation.id)
    if project is None:
        raise HTTPException(status_code=404, detail="Project not found")
    try:
        tools_arr = Toolkit.get_tool_and_toolkit_arr(db.session, organisation.id, agent_with_config.tools)
    except Exception as e:
        raise HTTPException(status_code=404, detail=str(e))

    # Settings the external API does not let callers choose are forced to fixed values.
    agent_with_config.tools = tools_arr
    agent_with_config.project_id = project.id
    agent_with_config.exit = "No exit criterion"
    agent_with_config.permission_type = "God Mode"
    agent_with_config.LTM_DB = None

    db_agent = Agent.create_agent_with_config(db, agent_with_config)

    if agent_with_config.schedule is not None:
        agent_schedule = AgentSchedule.save_schedule_from_config(db.session, db_agent, agent_with_config.schedule)
        if agent_schedule is None:
            raise HTTPException(status_code=500, detail="Failed to schedule agent")
        EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name,
                                                                        'model': agent_with_config.model}, db_agent.id,
                                                      organisation.id if organisation else 0)
        db.session.commit()
        return {
            "agent_id": db_agent.id
        }

    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
    # -1 marks "no iteration workflow" when the trigger step is not an ITERATION_WORKFLOW.
    iteration_step_id = IterationWorkflow.fetch_trigger_step_id(
        db.session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
    # Creating an execution with CREATED status; it is started later by create_run.
    execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
                               name="New Run", current_agent_step_id=start_step.id,
                               iteration_workflow_step_id=iteration_step_id)
    agent_execution_configs = {
        "goal": agent_with_config.goal,
        "instruction": agent_with_config.instruction,
        "constraints": agent_with_config.constraints,
        "exit": agent_with_config.exit,
        "tools": agent_with_config.tools,
        "iteration_interval": agent_with_config.iteration_interval,
        "model": agent_with_config.model,
        "permission_type": agent_with_config.permission_type,
        "LTM_DB": agent_with_config.LTM_DB,
        "max_iterations": agent_with_config.max_iterations,
        "user_timezone": agent_with_config.user_timezone,
        "knowledge": agent_with_config.knowledge
    }
    db.session.add(execution)
    db.session.commit()
    db.session.flush()
    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
                                                                     agent_execution_configs=agent_execution_configs)

    organisation = db_agent.get_agent_organisation(db.session)
    EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name,
                                                                    'model': agent_with_config.model}, db_agent.id,
                                                  organisation.id if organisation else 0)
    # The execution is intentionally not dispatched to the worker here.
    db.session.commit()
    return {
        "agent_id": db_agent.id
    }
@router.post("/{agent_id}/run", status_code=200)
def create_run(agent_id: int, agent_execution: AgentExecutionIn, api_key: str = Security(validate_api_key),
               organisation: Organisation = Depends(get_organisation_from_api_key)):
    """
    Start a run of an existing agent.

    Reuses the agent's pending CREATED execution when there is one, otherwise creates a new
    execution. Optional goal/instruction overrides are stored as execution configuration, and the
    run is queued with ``execute_agent.delay``.

    Args:
        agent_id: Id of the agent to run.
        agent_execution: Optional run name and goal/instruction overrides.
        api_key: API key, validated by ``validate_api_key``.
        organisation: Organisation resolved from the API key.

    Returns:
        dict: ``{"run_id": <execution id>}``.

    Raises:
        HTTPException: 404 if the agent does not exist or does not belong to the organisation;
            409 if the agent is scheduled.
    """
    agent = Agent.get_agent_from_id(db.session, agent_id)
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")
    project = Project.find_by_id(db.session, agent.project_id)
    # A missing project is reported like a foreign agent instead of crashing on attribute access.
    if project is None or project.organisation_id != organisation.id:
        raise HTTPException(status_code=404, detail="Agent not found")
    db_schedule = AgentSchedule.find_by_agent_id(db.session, agent_id)
    if db_schedule is not None:
        raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot run")

    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id)
    db_agent_execution = AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "CREATED")
    if db_agent_execution is None:
        # Record the iteration-workflow trigger step like the other execution creators do
        # (-1 when the start step is not an ITERATION_WORKFLOW).
        iteration_step_id = IterationWorkflow.fetch_trigger_step_id(
            db.session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
        db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(),
                                            agent_id=agent_id, name=agent_execution.name, num_of_calls=0,
                                            num_of_tokens=0,
                                            current_agent_step_id=start_step.id,
                                            iteration_workflow_step_id=iteration_step_id)
        db.session.add(db_agent_execution)
    else:
        db_agent_execution.status = "RUNNING"
    db.session.commit()
    db.session.flush()

    agent_execution_configs = {}
    if agent_execution.goal is not None:
        agent_execution_configs["goal"] = agent_execution.goal
    if agent_execution.instruction is not None:
        # Same key as the agent/execution configs use elsewhere ("instruction"), and the list
        # itself as the value (not wrapped in a tuple).
        agent_execution_configs["instruction"] = agent_execution.instruction
    if agent_execution_configs:
        AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session,
                                                                         execution=db_agent_execution,
                                                                         agent_execution_configs=agent_execution_configs)

    EventHandler(session=db.session).create_event('run_created',
                                                  {'agent_execution_id': db_agent_execution.id,
                                                   'agent_execution_name': db_agent_execution.name
                                                   },
                                                  agent_id,
                                                  organisation.id if organisation else 0)
    agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session=db.session,
                                                                                        key='knowledge',
                                                                                        agent_id=agent_id)
    if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
        knowledge = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value))
        # The referenced knowledge may no longer exist; skip the event rather than raising.
        if knowledge is not None and knowledge.name is not None:
            EventHandler(session=db.session).create_event('knowledge_picked',
                                                          {'knowledge_name': knowledge.name,
                                                           'agent_execution_id': db_agent_execution.id},
                                                          agent_id,
                                                          organisation.id if organisation else 0
                                                          )
    if db_agent_execution.status == "RUNNING":
        execute_agent.delay(db_agent_execution.id, datetime.now())
    return {
        "run_id": db_agent_execution.id
    }
@router.put("/{agent_id}",status_code=200)
def update_agent(agent_id: int, agent_with_config: AgentConfigUpdateExtInput,api_key: str = Security(validate_api_key),
                 organisation:Organisation = Depends(get_organisation_from_api_key)):
    """
    Update an existing agent from an external (API-key authenticated) request.

    Copies non-None input fields onto the agent row and stores the submitted settings as agent
    configuration. It then creates a new execution in CREATED state that carries the same settings.

    Args:
        agent_id: Id of the agent to update.
        agent_with_config: New settings; ``tools`` is resolved against the organisation's toolkits.
        api_key: API key, validated by ``validate_api_key``.
        organisation: Organisation resolved from the API key.

    Returns:
        dict: ``{"agent_id": <id>}``.

    Raises:
        HTTPException: 404 if the agent or its project is missing, the agent belongs to another
            organisation, or a requested tool cannot be resolved; 409 if the agent is scheduled;
            400 if the request tries to add a schedule.
    """
    db_agent= Agent.get_active_agent_by_id(db.session, agent_id)
    if not db_agent:
        raise HTTPException(status_code=404, detail="agent not found")
    project=Project.find_by_id(db.session, db_agent.project_id)
    if project is None:
        raise HTTPException(status_code=404, detail="Project not found")
    # Agents of other organisations are reported as missing rather than forbidden.
    if project.organisation_id!=organisation.id:
        raise HTTPException(status_code=404, detail="Agent not found")
    # db_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "RUNNING")
    # if db_execution is not None:
    #     raise HTTPException(status_code=409, detail="Agent is already running,please pause and then update")
    db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id)
    if db_schedule is not None:
        raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot update")
    try:
        tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,organisation.id,agent_with_config.tools)
    except Exception as e:
        raise HTTPException(status_code=404,detail=str(e))
    if agent_with_config.schedule is not None:
        raise HTTPException(status_code=400,detail="Cannot schedule an existing agent")
    # Settings the external API does not expose are forced to fixed values.
    agent_with_config.tools=tools_arr
    agent_with_config.project_id=project.id
    agent_with_config.exit="No exit criterion"
    agent_with_config.permission_type="God Mode"
    agent_with_config.LTM_DB=None
    # Copy only input fields that exist on the Agent model and are not None.
    for key,value in agent_with_config.dict().items():
        if hasattr(db_agent,key) and value is not None:
            setattr(db_agent,key,value)
    db.session.commit()
    db.session.flush()
    start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
    # -1 marks "no iteration workflow" when the trigger step is not an ITERATION_WORKFLOW.
    iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
                                                                start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
    execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
                               name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id)
    agent_execution_configs = {
        "goal": agent_with_config.goal,
        "instruction": agent_with_config.instruction,
        "tools":agent_with_config.tools,
        "constraints": agent_with_config.constraints,
        "iteration_interval": agent_with_config.iteration_interval,
        "model": agent_with_config.model,
        "max_iterations": agent_with_config.max_iterations,
        "agent_workflow": agent_with_config.agent_workflow,
    }
    # NOTE(review): this adds new AgentConfiguration rows for every key instead of updating the
    # existing ones. Keys whose input value is None are stored as the string "None".
    # Confirm both behaviours are intended.
    agent_configurations = [
        AgentConfiguration(agent_id=db_agent.id, key=key, value=str(value))
        for key, value in agent_execution_configs.items()
    ]
    db.session.add_all(agent_configurations)
    db.session.add(execution)
    db.session.commit()
    db.session.flush()
    AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
                                                                     agent_execution_configs=agent_execution_configs)
    db.session.commit()
    return {
        "agent_id":db_agent.id
    }
@router.post("/{agent_id}/run-status")
def get_agent_runs(agent_id: int, filter_config: RunFilterConfigIn, api_key: str = Security(validate_api_key),
                   organisation: Organisation = Depends(get_organisation_from_api_key)):
    """
    List the runs of an agent, optionally filtered by run ids and status.

    Args:
        agent_id: Id of the agent whose runs are listed.
        filter_config: Optional run ids and status filter; the status is upper-cased before use.
        api_key: API key, validated by ``validate_api_key``.
        organisation: Organisation resolved from the API key.

    Returns:
        list: ``[{"run_id": <id>, "status": <status>}, ...]``.

    Raises:
        HTTPException: 404 if the agent does not exist or does not belong to the organisation.
    """
    agent = Agent.get_active_agent_by_id(db.session, agent_id)
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")
    project = Project.find_by_id(db.session, agent.project_id)
    # A missing project is reported like a foreign agent instead of crashing on attribute access.
    if project is None or project.organisation_id != organisation.id:
        raise HTTPException(status_code=404, detail="Agent not found")
    if filter_config.run_status_filter is not None:
        # Execution statuses in this module are upper-case ("RUNNING", "PAUSED", ...).
        filter_config.run_status_filter = filter_config.run_status_filter.upper()
    db_execution_arr = AgentExecution.get_all_executions_by_filter_config(db.session, agent.id, filter_config)
    return [{"run_id": execution.id, "status": execution.status} for execution in db_execution_arr]
@router.post("/{agent_id}/pause", status_code=200)
def pause_agent_runs(agent_id: int, execution_state_change_input: ExecutionStateChangeConfigIn,
                     api_key: str = Security(validate_api_key),
                     organisation: Organisation = Depends(get_organisation_from_api_key)):
    """
    Pause RUNNING runs of an agent (all of them, or only those listed in ``run_ids``).

    Returns:
        dict: ``{"result": "success"}``.

    Raises:
        HTTPException: 404 if the agent does not exist or does not belong to the organisation,
            or if any requested run id is unknown, foreign, or not RUNNING.
    """
    agent = Agent.get_active_agent_by_id(db.session, agent_id)
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")
    project = Project.find_by_id(db.session, agent.project_id)
    # A missing project is reported like a foreign agent instead of crashing on attribute access.
    if project is None or project.organisation_id != organisation.id:
        raise HTTPException(status_code=404, detail="Agent not found")
    run_ids = execution_state_change_input.run_ids
    # Check that the requested run_ids belong to the organisation before changing their state.
    if run_ids is not None:
        try:
            AgentExecution.validate_run_ids(db.session, run_ids, organisation.id)
        except Exception:
            raise HTTPException(status_code=404, detail="One or more run id(s) not found")
    db_execution_arr = AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id,
                                                                                execution_state_change_input,
                                                                                "RUNNING")
    # Compare with the distinct ids so a repeated id in the request does not trigger a spurious
    # 404 (assumes the query returns one row per id — confirm).
    if db_execution_arr is not None and run_ids is not None \
            and len(db_execution_arr) != len(set(run_ids)):
        raise HTTPException(status_code=404, detail="One or more run id(s) not found")
    for ind_execution in db_execution_arr or []:
        ind_execution.status = "PAUSED"
    db.session.commit()
    db.session.flush()
    return {
        "result": "success"
    }
@router.post("/{agent_id}/resume", status_code=200)
def resume_agent_runs(agent_id: int, execution_state_change_input: ExecutionStateChangeConfigIn,
                      api_key: str = Security(validate_api_key),
                      organisation: Organisation = Depends(get_organisation_from_api_key)):
    """
    Resume PAUSED runs of an agent (all of them, or only those listed in ``run_ids``).

    Each resumed run is set to RUNNING and queued with ``execute_agent.delay``.

    Returns:
        dict: ``{"result": "success"}``.

    Raises:
        HTTPException: 404 if the agent does not exist or does not belong to the organisation,
            or if any requested run id is unknown, foreign, or not PAUSED.
    """
    agent = Agent.get_active_agent_by_id(db.session, agent_id)
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")
    project = Project.find_by_id(db.session, agent.project_id)
    # A missing project is reported like a foreign agent instead of crashing on attribute access.
    if project is None or project.organisation_id != organisation.id:
        raise HTTPException(status_code=404, detail="Agent not found")
    run_ids = execution_state_change_input.run_ids
    if run_ids is not None:
        try:
            AgentExecution.validate_run_ids(db.session, run_ids, organisation.id)
        except Exception:
            raise HTTPException(status_code=404, detail="One or more run id(s) not found")
    db_execution_arr = AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id,
                                                                                execution_state_change_input,
                                                                                "PAUSED")
    # Compare with the distinct ids so a repeated id in the request does not trigger a spurious
    # 404 (assumes the query returns one row per id — confirm).
    if db_execution_arr is not None and run_ids is not None \
            and len(db_execution_arr) != len(set(run_ids)):
        raise HTTPException(status_code=404, detail="One or more run id(s) not found")
    resumed_ids = []
    for ind_execution in db_execution_arr or []:
        ind_execution.status = "RUNNING"
        resumed_ids.append(ind_execution.id)
    # Commit the RUNNING status before queuing work. The worker presumably reads the execution row
    # when it starts, and must not see the stale PAUSED status.
    db.session.commit()
    db.session.flush()
    for execution_id in resumed_ids:
        execute_agent.delay(execution_id, datetime.now())
    return {
        "result": "success"
    }
@router.post("/resources/output", status_code=200)
def get_run_resources(run_id_config: RunIDConfig, api_key: str = Security(validate_api_key),
                      organisation: Organisation = Depends(get_organisation_from_api_key)):
    """
    Return download URLs for the output resources of the given runs.

    Args:
        run_id_config: Run ids whose resources are requested.
        api_key: API key, validated by ``validate_api_key``.
        organisation: Organisation resolved from the API key.

    Returns:
        Whatever ``S3Helper.get_download_url_of_resources`` returns for the runs' resources.

    Raises:
        HTTPException: 400 if storage is not S3; 404 if no run ids are given or any id does not
            belong to the organisation; 401 if generating the S3 URLs fails.
    """
    if get_config('STORAGE_TYPE') != "S3":
        raise HTTPException(status_code=400, detail="This endpoint only works when S3 is configured")
    run_ids_arr = run_id_config.run_ids
    if not run_ids_arr:
        raise HTTPException(status_code=404,
                            detail="No execution_id found")
    # Check that the run_ids whose output files are requested belong to the organisation.
    try:
        AgentExecution.validate_run_ids(db.session, run_ids_arr, organisation.id)
    except Exception:
        raise HTTPException(status_code=404, detail="One or more run id(s) not found")
    db_resources_arr = Resource.find_by_run_ids(db.session, run_ids_arr)
    try:
        response_obj = S3Helper().get_download_url_of_resources(db_resources_arr)
    except Exception as e:
        # Catch Exception, not everything: a bare except would also swallow SystemExit/KeyboardInterrupt.
        raise HTTPException(status_code=401, detail="Invalid S3 credentials") from e
    return response_obj
================================================
FILE: superagi/controllers/api_key.py
================================================
import json
import uuid
from fastapi import APIRouter, Body
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from superagi.helper.auth import get_user_organisation, validate_api_key
from superagi.helper.auth import check_auth
from superagi.models.api_key import ApiKey
from typing import Optional, Annotated
router = APIRouter()


class ApiKeyIn(BaseModel):
    """Request body for renaming an API key: the key's id and its new name."""
    id: int
    name: str

    class Config:
        # Allow the model to be populated from ORM objects (pydantic v1 orm_mode).
        orm_mode = True


class ApiKeyDeleteIn(BaseModel):
    """Request body identifying an API key by id."""
    id: int

    class Config:
        orm_mode = True
@router.post("")
def create_api_key(name: Annotated[str, Body(embed=True)], Authorize: AuthJWT = Depends(check_auth),
                   organisation=Depends(get_user_organisation)):
    """
    Generate and store a new API key for the caller's organisation.

    Args:
        name: Label for the key, read from the JSON body as ``{"name": ...}``.
        Authorize: JWT authorisation, checked by ``check_auth``.
        organisation: The caller's organisation, which will own the key.

    Returns:
        dict: ``{"api_key": <generated key>}`` — the same random UUID4 string that is stored.
    """
    generated_key = str(uuid.uuid4())
    db.session.add(ApiKey(key=generated_key, name=name, org_id=organisation.id))
    db.session.commit()
    db.session.flush()
    return {"api_key": generated_key}
@router.get("/validate")
def validate_api_key(api_key: str = Depends(validate_api_key)):
    """
    Report success once the supplied API key has passed validation.

    The ``Depends(validate_api_key)`` default is evaluated when this ``def`` runs, so it refers to
    the helper imported from ``superagi.helper.auth``. This route function rebinds the module name
    only afterwards. Rejection of a bad key is presumably signalled by that helper raising.
    """
    result = {"success": True}
    return result
@router.get("")
def get_all(Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)):
    """
    List the API keys that belong to the caller's organisation.

    Returns:
        Whatever ``ApiKey.get_by_org_id`` returns for the organisation id.
    """
    return ApiKey.get_by_org_id(db.session, organisation.id)
@router.delete("/{api_key_id}")
def delete_api_key(api_key_id: int, Authorize: AuthJWT = Depends(check_auth),
                   organisation=Depends(get_user_organisation)):
    """
    Delete one of the caller's organisation's API keys.

    Args:
        api_key_id: Id of the key to delete.
        Authorize: JWT authorisation, checked by ``check_auth``.
        organisation: The caller's organisation; only its own keys may be deleted.

    Returns:
        dict: ``{"success": True}``.

    Raises:
        HTTPException: 404 if the key does not exist or belongs to another organisation. Both
            cases get the same response, so other organisations' key ids are not revealed.
    """
    api_key = ApiKey.get_by_id(db.session, api_key_id)
    # Without the ownership check any authenticated user could delete another organisation's key.
    if api_key is None or api_key.org_id != organisation.id:
        raise HTTPException(status_code=404, detail="API key not found")
    ApiKey.delete_by_id(db.session, api_key_id)
    return {"success": True}
@router.put("")
def edit_api_key(api_key_in: ApiKeyIn, Authorize: AuthJWT = Depends(check_auth),
                 organisation=Depends(get_user_organisation)):
    """
    Rename one of the caller's organisation's API keys.

    Args:
        api_key_in: Id of the key and its new name.
        Authorize: JWT authorisation, checked by ``check_auth``.
        organisation: The caller's organisation; only its own keys may be renamed.

    Returns:
        dict: ``{"success": True}``.

    Raises:
        HTTPException: 404 if the key does not exist or belongs to another organisation.
    """
    api_key = ApiKey.get_by_id(db.session, api_key_in.id)
    # Without the ownership check any authenticated user could rename another organisation's key.
    if api_key is None or api_key.org_id != organisation.id:
        raise HTTPException(status_code=404, detail="API key not found")
    ApiKey.update_api_key(db.session, api_key_in.id, api_key_in.name)
    return {"success": True}
================================================
FILE: superagi/controllers/budget.py
================================================
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from superagi.helper.auth import check_auth
from superagi.models.budget import Budget
# from superagi.types.db import BudgetIn, BudgetOut
router = APIRouter()


class BudgetOut(BaseModel):
    """Response schema for a budget: its id, amount and cycle."""
    id: int
    budget: float
    cycle: str

    class Config:
        # Lets FastAPI build the response directly from a Budget ORM object (pydantic v1 orm_mode).
        orm_mode = True


class BudgetIn(BaseModel):
    """Request body for creating or updating a budget."""
    budget: float
    cycle: str

    class Config:
        orm_mode = True
@router.post("/add", response_model=BudgetOut, status_code=201)
def create_budget(budget: BudgetIn,
                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new budget.

    Args:
        budget: Amount and cycle for the new budget.
        Authorize: JWT authorisation, checked by ``check_auth``.

    Returns:
        Budget: The persisted budget row, serialised through ``BudgetOut``.

    NOTE(review): budgets carry no organisation here, so they are not scoped per organisation —
    confirm this is intended.
    """
    created = Budget(budget=budget.budget, cycle=budget.cycle)
    db.session.add(created)
    db.session.commit()
    return created
@router.get("/get/{budget_id}", response_model=BudgetOut)
def get_budget(budget_id: int,
               Authorize: AuthJWT = Depends(check_auth)):
    """
    Get a budget by its id.

    Args:
        budget_id: Id of the budget.
        Authorize: JWT authorisation, checked by ``check_auth``.

    Returns:
        Budget: The matching row, serialised through ``BudgetOut``.

    Raises:
        HTTPException: 404 if no budget has that id.
    """
    budget_row = (
        db.session.query(Budget)
        .filter(Budget.id == budget_id)
        .first()
    )
    if not budget_row:
        raise HTTPException(status_code=404, detail="budget not found")
    return budget_row
@router.put("/update/{budget_id}", response_model=BudgetOut)
def update_budget(budget_id: int, budget: BudgetIn,
                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Overwrite the amount and cycle of an existing budget.

    Args:
        budget_id: Id of the budget to update.
        budget: New amount and cycle.
        Authorize: JWT authorisation, checked by ``check_auth``.

    Returns:
        Budget: The updated row, serialised through ``BudgetOut``.

    Raises:
        HTTPException: 404 if no budget has that id.
    """
    budget_row = (
        db.session.query(Budget)
        .filter(Budget.id == budget_id)
        .first()
    )
    if not budget_row:
        raise HTTPException(status_code=404, detail="budget not found")
    budget_row.budget, budget_row.cycle = budget.budget, budget.cycle
    db.session.commit()
    return budget_row
================================================
FILE: superagi/controllers/config.py
================================================
from datetime import datetime
from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from superagi.models.models_config import ModelsConfig
from superagi.models.configuration import Configuration
from superagi.models.organisation import Organisation
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Request
from superagi.config.config import get_config
from superagi.helper.auth import check_auth
from fastapi_jwt_auth import AuthJWT
from superagi.helper.encyption_helper import encrypt_data,decrypt_data
from superagi.lib.logger import logger
# from superagi.types.db import ConfigurationIn, ConfigurationOut
router = APIRouter()


class ConfigurationOut(BaseModel):
    """Response schema for an organisation-level configuration entry."""
    id: int
    organisation_id: int
    key: str
    value: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Lets FastAPI build the response from a Configuration ORM object (pydantic v1 orm_mode).
        orm_mode = True


class ConfigurationIn(BaseModel):
    """
    Request body for creating or updating an organisation-level configuration entry.

    ``organisation_id`` is optional; create_config takes the organisation id from the URL path.
    """
    organisation_id: Optional[int]
    key: str
    value: str

    class Config:
        orm_mode = True
# CRUD Operations
@router.post("/add/organisation/{organisation_id}", status_code=201,
             response_model=ConfigurationOut)
def create_config(config: ConfigurationIn, organisation_id: int,
                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Creates (or updates, if the key already exists) an organisation-level config.

    Args:
        config (ConfigurationIn): Key and value to store. ``config.organisation_id`` is ignored;
            the path parameter is used.
        organisation_id (int): ID of the organisation.
        Authorize (AuthJWT): JWT authorisation, checked by ``check_auth``.

    Returns:
        Configuration: The created or updated configuration.

    Raises:
        HTTPException: 404 if the organisation does not exist.
    """
    db_organisation = db.session.query(Organisation).filter(Organisation.id == organisation_id).first()
    if not db_organisation:
        raise HTTPException(status_code=404, detail="Organisation not found")

    existing_config = (
        db.session.query(Configuration)
        .filter(Configuration.organisation_id == organisation_id, Configuration.key == config.key)
        .first()
    )

    # The model API key is stored encrypted; other keys are stored as given.
    if config.key == "model_api_key":
        config.value = encrypt_data(config.value)

    if existing_config:
        existing_config.value = config.value
        db.session.commit()
        db.session.flush()
        return existing_config

    logger.info("NEW CONFIG")
    new_config = Configuration(organisation_id=organisation_id, key=config.key, value=config.value)
    logger.info(new_config)
    # Format the id into the message itself instead of passing it as a stray positional argument.
    logger.info(f"ORGANISATION ID : {organisation_id}")
    db.session.add(new_config)
    db.session.commit()
    db.session.flush()
    return new_config
@router.get("/get/organisation/{organisation_id}/key/{key}", status_code=200)
def get_config_by_organisation_id_and_key(organisation_id: int, key: str,
                                          Authorize: AuthJWT = Depends(check_auth)):
    """
    Return the organisation's OpenAI model config, seeding a ``model_api_key`` config from the
    environment when none exists.

    NOTE(review): ``key`` is accepted but not used by the lookup below — confirm callers only
    use this route for the model API key.

    Args:
        organisation_id (int): ID of the organisation.
        key (str): Key of the configuration (currently unused, see note).
        Authorize (AuthJWT, optional): Authorization JWT token. Defaults to Depends(check_auth).

    Returns:
        The existing ``ModelsConfig`` row. Otherwise a new encrypted ``model_api_key``
        ``Configuration`` built from the first non-placeholder OPENAI_API_KEY / PALM_API_KEY.
        ``None`` when neither is usable.

    Raises:
        HTTPException: 404 if the organisation does not exist.
    """
    db_organisation = db.session.query(Organisation).filter(Organisation.id == organisation_id).first()
    if not db_organisation:
        raise HTTPException(status_code=404, detail="Organisation not found")
    config = db.session.query(ModelsConfig).filter(ModelsConfig.org_id == organisation_id,
                                                   ModelsConfig.provider == 'OpenAI').first()
    if config is None:
        # Presumably the template placeholder values; they must never be stored as real keys.
        # The original `(k != A) or (k != B)` test was always true, so placeholders got stored.
        placeholders = ("YOUR_OPEN_API_KEY", "YOUR_PALM_API_KEY")
        candidates = (get_config("OPENAI_API_KEY"), get_config("PALM_API_KEY"))
        # First candidate that is set and not a placeholder, so a placeholder OpenAI key does
        # not hide a real PaLM key.
        api_key = next((candidate for candidate in candidates
                        if candidate and candidate not in placeholders), None)
        if api_key is not None:
            encrypted_data = encrypt_data(api_key)
            new_config = Configuration(organisation_id=organisation_id, key="model_api_key", value=encrypted_data)
            db.session.add(new_config)
            db.session.commit()
            db.session.flush()
            return new_config
    return config
@router.get("/get/organisation/{organisation_id}", status_code=201)
def get_config_by_organisation_id(organisation_id: int,
                                  Authorize: AuthJWT = Depends(check_auth)):
    """
    Get all configurations for a given organisation ID.

    The ``model_api_key`` value is returned decrypted.

    Args:
        organisation_id (int): ID of the organisation.
        Authorize (AuthJWT, optional): Authorization JWT token. Defaults to Depends(check_auth).

    Returns:
        List[Configuration]: List of configurations for the organisation.

    Raises:
        HTTPException: 404 if the organisation does not exist.
    """
    db_organisation = db.session.query(Organisation).filter(Organisation.id == organisation_id).first()
    if not db_organisation:
        raise HTTPException(status_code=404, detail="Organisation not found")
    configs = db.session.query(Configuration).filter(Configuration.organisation_id == organisation_id).all()
    for config in configs:
        if config.key == "model_api_key":
            # Detach the row before replacing its value with the plaintext. Otherwise any later
            # flush/commit on this session would write the decrypted key back to the database.
            db.session.expunge(config)
            config.value = decrypt_data(config.value)
    return configs
@router.get("/get/env", status_code=200)
def current_env():
    """
    Report the deployment environment.

    Returns:
        dict: ``{"env": <value>}`` where the value is ``get_config("ENV", "DEV")``.
    """
    return {"env": get_config("ENV", "DEV")}
================================================
FILE: superagi/controllers/google_oauth.py
================================================
from fastapi import Depends, Query
from fastapi import APIRouter
from fastapi.responses import RedirectResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from sqlalchemy.orm import sessionmaker
from fastapi import HTTPException
import superagi
import json
import requests
from datetime import datetime, timedelta
from superagi.models.db import connect_db
import http.client as http_client
from superagi.helper.auth import get_current_user, check_auth
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens
from superagi.config.config import get_config
from superagi.helper.encyption_helper import decrypt_data, is_encrypted
router = APIRouter()


def _get_toolkit_secret(toolkit_id: int, key: str):
    """
    Read the ToolConfig value stored under ``key`` for a toolkit, decrypting it if it is encrypted.

    Raises:
        HTTPException: 404 if the toolkit has no ToolConfig row for ``key``.
    """
    tool_config = db.session.query(ToolConfig).filter(ToolConfig.key == key,
                                                      ToolConfig.toolkit_id == toolkit_id).first()
    if tool_config is None:
        raise HTTPException(status_code=404, detail=f"{key} is not configured for this toolkit")
    if is_encrypted(tool_config.value):
        return decrypt_data(tool_config.value)
    return tool_config.value


@router.get('/oauth-tokens')
async def google_auth_calendar(code: str = Query(...), state: str = Query(...)):
    """
    Google OAuth callback: exchange ``code`` for tokens and redirect to the frontend with them.

    ``state`` carries the id of the toolkit whose GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET are used.

    NOTE(review): this is an ``async def`` that makes blocking DB and HTTP calls, which stalls the
    event loop while they run. Consider a plain ``def`` so FastAPI runs it in its threadpool.

    Raises:
        HTTPException: 400 if ``state`` is not an integer or Google rejects the exchange;
            404 if the toolkit lacks client credentials; 502 if the token endpoint is unreachable.
    """
    try:
        toolkit_id = int(state)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid state parameter")
    client_id = _get_toolkit_secret(toolkit_id, "GOOGLE_CLIENT_ID")
    client_secret = _get_toolkit_secret(toolkit_id, "GOOGLE_CLIENT_SECRET")
    token_uri = 'https://oauth2.googleapis.com/token'
    scope = 'https://www.googleapis.com/auth/calendar'
    env = get_config("ENV", "DEV")
    if env == "DEV":
        redirect_uri = "http://localhost:3000/api/google/oauth-tokens"
    else:
        redirect_uri = "https://app.superagi.com/api/google/oauth-tokens"
    params = {
        'client_id': client_id,
        'client_secret': client_secret,
        'redirect_uri': redirect_uri,
        'scope': scope,
        'grant_type': 'authorization_code',
        'code': code,
        'access_type': 'offline',
        'approval_prompt': 'force'
    }
    try:
        # Bound the outbound call; without a timeout a stalled connection blocks this handler forever.
        response = requests.post(token_uri, data=params, timeout=30)
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail="Could not reach Google token endpoint") from e
    if response.status_code != 200:
        raise HTTPException(status_code=400, detail="Invalid Client Secret")
    response = response.json()
    # The recorded expiry is 5 minutes earlier than Google's, presumably so tokens get refreshed early.
    expire_time = datetime.utcnow() + timedelta(seconds=response['expires_in'])
    expire_time = expire_time - timedelta(minutes=5)
    response['expiry'] = expire_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
    response_data = json.dumps(response)
    frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
    # NOTE(review): the token JSON is appended to the query string without URL-encoding, and
    # tokens in URLs can leak via history/logs — confirm what the frontend expects before changing.
    redirect_url_success = f"{frontend_url}/google_calendar_creds/?{response_data}"
    return RedirectResponse(url=redirect_url_success)
@router.post("/send_google_creds/toolkit_id/{toolkit_id}")
def send_google_calendar_configs(google_creds: dict, toolkit_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Store the current user's Google Calendar OAuth token payload for a toolkit.

    Args:
        google_creds: Token payload from the frontend, stored JSON-serialised.
        toolkit_id: Id of the Google Calendar toolkit.
        Authorize: JWT authorisation, checked by ``check_auth``; also identifies the user.

    Returns:
        bool: Whether ``OauthTokens.add_or_update`` returned a truthy result.

    Raises:
        HTTPException: 404 if the toolkit does not exist.
    """
    current_user = get_current_user(Authorize)
    toolkit = db.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
    if toolkit is None:
        raise HTTPException(status_code=404, detail="Toolkit not found")
    # The payload holds OAuth tokens, so it is deliberately not printed or logged.
    serialized_creds = json.dumps(google_creds)
    engine = connect_db()
    Session = sessionmaker(bind=engine)
    session = Session()
    try:
        tokens = OauthTokens().add_or_update(session, toolkit_id, current_user.id, toolkit.organisation_id,
                                             "GOOGLE_CALENDAR_OAUTH_TOKENS", serialized_creds)
    finally:
        # This session is private to the request; close it so it releases the connection it holds.
        session.close()
    return bool(tokens)
@router.get("/get_google_creds/toolkit_id/{toolkit_id}")
def get_google_calendar_tool_configs(toolkit_id: int):
    """
    Return the Google client id configured for a toolkit, decrypted if it is stored encrypted.

    Returns:
        dict: ``{"client_id": <client id>}``.

    Raises:
        HTTPException: 404 if no GOOGLE_CLIENT_ID is configured for the toolkit.
    """
    google_calendar_config = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,
                                                                 ToolConfig.key == "GOOGLE_CLIENT_ID").first()
    if google_calendar_config is None:
        raise HTTPException(status_code=404, detail="GOOGLE_CLIENT_ID is not configured for this toolkit")
    client_id = google_calendar_config.value
    # Decrypt into a local rather than assigning to the ORM object, so the plaintext can never be
    # flushed back to the database by a later commit on this session.
    if is_encrypted(client_id):
        client_id = decrypt_data(client_id)
    return {
        "client_id": client_id
    }
================================================
FILE: superagi/controllers/knowledge_configs.py
================================================
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Query, status
from fastapi import APIRouter
from superagi.config.config import get_config
from superagi.helper.auth import check_auth
from superagi.models.knowledge_configs import KnowledgeConfigs
from fastapi_jwt_auth import AuthJWT
# FastAPI router on which the knowledge-config endpoints below are registered.
router = APIRouter()
@router.get("/marketplace/details/{knowledge_id}")
def get_marketplace_knowledge_configs(knowledge_id: int):
    """
    Fetch every KnowledgeConfigs row attached to a knowledge entry.

    Args:
        knowledge_id (int): ID of the knowledge entry whose configs are requested.

    Returns:
        list[KnowledgeConfigs]: All matching config rows (possibly empty).
    """
    config_query = db.session.query(KnowledgeConfigs)
    matching_configs = config_query.filter(KnowledgeConfigs.knowledge_id == knowledge_id)
    return matching_configs.all()
================================================
FILE: superagi/controllers/knowledges.py
================================================
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Query, status
from fastapi import APIRouter
from datetime import datetime
from superagi.config.config import get_config
from superagi.helper.auth import get_user_organisation
from superagi.models.knowledges import Knowledges
from superagi.models.marketplace_stats import MarketPlaceStats
from superagi.models.knowledge_configs import KnowledgeConfigs
from superagi.models.vector_db_indices import VectordbIndices
from superagi.models.vector_dbs import Vectordbs
from superagi.helper.s3_helper import S3Helper
from superagi.models.vector_db_configs import VectordbConfigs
from superagi.vector_store.vector_factory import VectorFactory
from superagi.vector_embeddings.vector_embedding_factory import VectorEmbeddingFactory
from superagi.helper.time_helper import get_time_difference
# FastAPI router on which the knowledge endpoints below are registered.
router = APIRouter()
@router.get("/get/list")
def get_knowledge_list(
        page: int = Query(None, title="Page Number"),
        organisation = Depends(get_user_organisation)
):
    """
    Get Marketplace Knowledge list.

    Args:
        page (int, optional): Zero-based page number. Defaults to None, which is treated as page 0;
            negative values are also clamped to 0.
        organisation: The caller's organisation (injected dependency).

    Returns:
        The marketplace knowledge entries with install details for the organisation, each
        annotated with an "install_number" count.
    """
    # `page` defaults to None and `None < 0` raises TypeError, so normalise it before comparing.
    if page is None or page < 0:
        page = 0
    marketplace_knowledges = Knowledges.fetch_marketplace_list(page)
    marketplace_knowledges_with_install = Knowledges.get_knowledge_install_details(db.session, marketplace_knowledges, organisation)
    for knowledge in marketplace_knowledges_with_install:
        knowledge["install_number"] = MarketPlaceStats.get_knowledge_installation_number(knowledge["id"])
    return marketplace_knowledges_with_install
@router.get("/marketplace/list/{page}")
def get_marketplace_knowledge_list(page: int = 0):
    """
    List knowledge published by the marketplace organisation.

    Args:
        page (int): Zero-based page number (page size 30). A negative page returns every entry
            unpaginated.

    Returns:
        list[Knowledges]: Knowledge rows owned by MARKETPLACE_ORGANISATION_ID.
    """
    organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    page_size = 30
    query = db.session.query(Knowledges).filter(Knowledges.organisation_id == organisation_id)
    if page < 0:
        # Previously this result was discarded and a negative OFFSET was issued instead, which
        # databases such as PostgreSQL reject.
        return query.all()
    return query.offset(page * page_size).limit(page_size).all()
@router.get("/user/list")
def get_user_knowledge_list(organisation = Depends(get_user_organisation)):
    """
    List the organisation's knowledge, flagging entries whose name also appears on the first
    marketplace page.

    Args:
        organisation: The caller's organisation (injected dependency).

    Returns:
        The organisation's knowledge entries, each with an "is_marketplace" boolean.
    """
    marketplace_knowledges = Knowledges.fetch_marketplace_list(page=0)
    # Build the name set once instead of rebuilding a list for every user knowledge entry.
    marketplace_names = {knowledge['name'] for knowledge in marketplace_knowledges}
    user_knowledge_list = Knowledges.get_organisation_knowledges(db.session, organisation)
    for user_knowledge in user_knowledge_list:
        user_knowledge["is_marketplace"] = user_knowledge["name"] in marketplace_names
    return user_knowledge_list
@router.get("/marketplace/get/details/{knowledge_name}")
def get_knowledge_details(knowledge_name: str):
    """
    Return a marketplace knowledge entry merged with its config, install count and a
    human-readable update date.

    Args:
        knowledge_name (str): Name of the marketplace knowledge entry.

    Returns:
        dict: Knowledge fields merged with config fields, plus "install_number"; "updated_at"
        is reformatted as e.g. "05 March 2024".
    """
    details = Knowledges.fetch_knowledge_details_marketplace(knowledge_name)
    configs = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(details["id"])
    merged = details | configs
    merged["install_number"] = MarketPlaceStats.get_knowledge_installation_number(merged["id"])
    parsed_update_time = datetime.strptime(str(merged["updated_at"]), "%Y-%m-%dT%H:%M:%S.%f")
    merged["updated_at"] = parsed_update_time.strftime('%d %B %Y')
    return merged
@router.get("/marketplace/details/{knowledge_name}")
def get_marketplace_knowledge_details(knowledge_name: str):
    """
    Look up a knowledge entry by name within the marketplace organisation.

    Args:
        knowledge_name (str): Name of the knowledge entry.

    Returns:
        Knowledges | None: The first matching row, or None if there is none.
    """
    marketplace_org_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    return (
        db.session.query(Knowledges)
        .filter(Knowledges.name == knowledge_name, Knowledges.organisation_id == marketplace_org_id)
        .first()
    )
@router.get("/user/get/details/{knowledge_id}")
def get_user_knowledge_details(knowledge_id: int):
    """
    Describe a user knowledge entry together with its vector index, vector database and config.

    Args:
        knowledge_id (int): ID of the knowledge entry.

    Returns:
        dict: name, description, vector_database_index {id, name}, vector_database,
        installation_type (the index state), merged with the knowledge's config mapping.
    """
    knowledge_row = Knowledges.get_knowledge_from_id(db.session, knowledge_id)
    index_row = VectordbIndices.get_vector_index_from_id(db.session, knowledge_row.vector_db_index_id)
    vector_db_row = Vectordbs.get_vector_db_from_id(db.session, index_row.vector_db_id)
    summary = dict(
        name=knowledge_row.name,
        description=knowledge_row.description,
        vector_database_index=dict(id=index_row.id, name=index_row.name),
        vector_database=vector_db_row.name,
        installation_type=index_row.state,
    )
    config = KnowledgeConfigs.get_knowledge_config_from_knowledge_id(db.session, knowledge_id)
    return summary | config
@router.post("/add_or_update/data")
def add_update_user_knowledge(knowledge_data: dict, organisation = Depends(get_user_organisation)):
    """
    Create or update a knowledge entry owned by the caller's organisation.

    Args:
        knowledge_data (dict): Knowledge fields; organisation_id and contributed_by are overwritten
            from the caller's organisation.
        organisation: The caller's organisation (injected dependency).

    Returns:
        dict: {"id": <knowledge id>}.

    Raises:
        HTTPException (status_code=404): If Knowledges.add_update_knowledge returns a falsy value.
    """
    knowledge_data.update(organisation_id=organisation.id, contributed_by=organisation.name)
    saved_knowledge = Knowledges.add_update_knowledge(db.session, knowledge_data)
    if saved_knowledge:
        return {"id": saved_knowledge.id}
    raise HTTPException(status_code=404, detail="Knowledge not found")
@router.post("/delete/{knowledge_id}")
def delete_user_knowledge(knowledge_id: int):
    """
    Delete a user knowledge entry.

    Args:
        knowledge_id (int): ID of the knowledge entry to delete.

    Raises:
        HTTPException (status_code=404): If Knowledges.delete_knowledge raises an error.
    """
    try:
        Knowledges.delete_knowledge(db.session, knowledge_id)
    except Exception as err:
        # A bare `except:` would also trap KeyboardInterrupt/SystemExit; only map real errors,
        # and keep the original cause attached for debugging.
        raise HTTPException(status_code=404, detail="Knowledge not found") from err
@router.get("/install/{knowledge_name}/index/{vector_db_index_id}")
def install_selected_knowledge(knowledge_name: str, vector_db_index_id: int, organisation = Depends(get_user_organisation)):
    """
    Install a marketplace knowledge entry into one of the organisation's vector indices.

    Steps: load the marketplace knowledge and its config, download its chunk file from S3,
    embed the chunks, upsert them into the target vector DB, record a local copy of the
    knowledge and its config (without file_path), mark the index as "Marketplace", and bump
    the marketplace install counter.

    Args:
        knowledge_name (str): Name of the marketplace knowledge entry.
        vector_db_index_id (int): ID of the target vector DB index.
        organisation: The caller's organisation (injected dependency).

    Raises:
        HTTPException (status_code=400): If building the vector store or upserting embeddings fails.
    """
    vector_db_index = VectordbIndices.get_vector_index_from_id(db.session, vector_db_index_id)
    selected_knowledge = Knowledges.fetch_knowledge_details_marketplace(knowledge_name)
    selected_knowledge_config = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(selected_knowledge['id'])
    file_chunks = S3Helper().get_json_file(selected_knowledge_config["file_path"])
    vector = Vectordbs.get_vector_db_from_id(db.session, vector_db_index.vector_db_id)
    db_creds = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector.id)
    upsert_data = VectorEmbeddingFactory.build_vector_storage(vector.db_type, file_chunks).get_vector_embeddings_from_chunks()
    try:
        vector_db_storage = VectorFactory.build_vector_storage(vector.db_type, vector_db_index.name, **db_creds)
        vector_db_storage.add_embeddings_to_vector_db(upsert_data)
    except Exception as err:
        # `detail` is rendered into a JSON body; a raw exception object is not JSON-serialisable
        # and would turn this 400 into a 500, so send its message instead.
        raise HTTPException(status_code=400, detail=str(err)) from err
    selected_knowledge_data = {
        "id": -1,
        "name": selected_knowledge["name"],
        "description": selected_knowledge["description"],
        "index_id": vector_db_index_id,
        "organisation_id": organisation.id,
        "contributed_by": selected_knowledge["contributed_by"],
    }
    new_knowledge = Knowledges.add_update_knowledge(db.session, selected_knowledge_data)
    # The S3 file path belongs to the marketplace copy only; don't store it on the installed copy.
    removable_key = 'file_path'
    selected_knowledge_config.pop(removable_key)
    configs = selected_knowledge_config
    KnowledgeConfigs.add_update_knowledge_config(db.session, new_knowledge.id, configs)
    VectordbIndices.update_vector_index_state(db.session, vector_db_index_id, "Marketplace")
    install_number = MarketPlaceStats.get_knowledge_installation_number(selected_knowledge["id"])
    MarketPlaceStats.update_knowledge_install_number(db.session, selected_knowledge["id"], int(install_number) + 1)
@router.post("/uninstall/{knowledge_name}")
def uninstall_selected_knowledge(knowledge_name: str, organisation = Depends(get_user_organisation)):
    """
    Remove an installed knowledge entry: delete its embeddings from the vector DB, then its
    config and the knowledge row itself.

    Args:
        knowledge_name (str): Name of the installed knowledge entry.
        organisation: The caller's organisation (injected dependency).

    Raises:
        HTTPException (status_code=404): If the organisation has no knowledge with that name.
        HTTPException (status_code=400): If building the vector store or deleting embeddings fails.
    """
    import ast

    knowledge = db.session.query(Knowledges).filter(Knowledges.name == knowledge_name, Knowledges.organisation_id == organisation.id).first()
    if knowledge is None:
        raise HTTPException(status_code=404, detail="Knowledge not found")
    knowledge_config = KnowledgeConfigs.get_knowledge_config_from_knowledge_id(db.session, knowledge.id)
    # vector_ids appears to be stored as a Python-literal string (it was previously eval()-ed).
    # Parse it with literal_eval so database content can never execute arbitrary expressions.
    vector_ids = ast.literal_eval(knowledge_config["vector_ids"])
    vector_db_index = VectordbIndices.get_vector_index_from_id(db.session, knowledge.vector_db_index_id)
    vector = Vectordbs.get_vector_db_from_id(db.session, vector_db_index.vector_db_id)
    db_creds = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector.id)
    try:
        vector_db_storage = VectorFactory.build_vector_storage(vector.db_type, vector_db_index.name, **db_creds)
        vector_db_storage.delete_embeddings_from_vector_db(vector_ids)
    except Exception as err:
        # Exception objects are not JSON-serialisable as `detail`; send the message.
        raise HTTPException(status_code=400, detail=str(err)) from err
    KnowledgeConfigs.delete_knowledge_config(db.session, knowledge.id)
    Knowledges.delete_knowledge(db.session, knowledge.id)
================================================
FILE: superagi/controllers/marketplace_stats.py
================================================
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Query, status
from fastapi import APIRouter
from superagi.config.config import get_config
from superagi.models.marketplace_stats import MarketPlaceStats
from superagi.models.vector_dbs import Vectordbs
# FastAPI router on which the marketplace-stats endpoints below are registered.
router = APIRouter()
@router.get("/knowledge/downloads/{knowledge_id}")
def count_knowledge_downloads(knowledge_id: int):
    """
    Return the recorded download count for a knowledge entry.

    Args:
        knowledge_id (int): ID of the knowledge entry.

    Returns:
        The stored "download_count" value, or 0 when no stats row exists.
    """
    download_stat = (
        db.session.query(MarketPlaceStats)
        .filter(MarketPlaceStats.reference_id == knowledge_id,
                MarketPlaceStats.reference_name == "KNOWLEDGE",
                MarketPlaceStats.key == "download_count")
        .first()
    )
    return 0 if download_stat is None else download_stat.value
================================================
FILE: superagi/controllers/models_controller.py
================================================
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.helper.models_helper import ModelsHelper
from superagi.apm.call_log_helper import CallLogHelper
from superagi.lib.logger import logger
from superagi.models.models import Models
from superagi.models.models_config import ModelsConfig
from superagi.config.config import get_config
from superagi.controllers.types.models_types import ModelsTypes
from fastapi_sqlalchemy import db
import logging
from pydantic import BaseModel
from superagi.helper.llm_loader import LLMLoader
# FastAPI router on which the model-management endpoints below are registered.
router = APIRouter()
class ValidateAPIKeyRequest(BaseModel):
    """Request body for /store_api_keys: an API key to store for a model provider."""
    # Name of the model provider the key belongs to.
    model_provider: str
    # The provider API key to store for the caller's organisation.
    model_api_key: str
class StoreModelRequest(BaseModel):
    """Request body for /store_model describing a model to register for the organisation."""
    model_name: str
    description: str
    end_point: str
    # ID of the provider config row this model uses.
    model_provider_id: int
    token_limit: int
    type: str
    version: str
    # Optional; may be None when the client does not send it.
    context_length: Optional[int]
class ModelName (BaseModel):
    """Request body for /fetch_model_data: the model whose call-log data is requested."""
    model: str
@router.post("/store_api_keys", status_code=200)
async def store_api_keys(request: ValidateAPIKeyRequest, organisation=Depends(get_user_organisation)):
    """
    Store a model provider API key for the caller's organisation.

    Raises:
        HTTPException (status_code=500): If ModelsConfig.store_api_key raises.
    """
    try:
        stored = ModelsConfig.store_api_key(db.session, organisation.id, request.model_provider,
                                            request.model_api_key)
    except Exception as exc:
        logging.error(f"Error while storing API key: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return stored
@router.get("/get_api_keys")
async def get_api_keys(organisation=Depends(get_user_organisation)):
    """
    Return all stored provider API keys for the caller's organisation.

    Raises:
        HTTPException (status_code=500): If ModelsConfig.fetch_api_keys raises.
    """
    try:
        api_keys = ModelsConfig.fetch_api_keys(db.session, organisation.id)
    except Exception as exc:
        logging.error(f"Error while retrieving API Keys: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return api_keys
@router.get("/get_api_key", status_code=200)
async def get_api_key(model_provider: str = None, organisation=Depends(get_user_organisation)):
    """
    Return the stored API key for one provider of the caller's organisation.

    Raises:
        HTTPException (status_code=500): If ModelsConfig.fetch_api_key raises.
    """
    try:
        api_key = ModelsConfig.fetch_api_key(db.session, organisation.id, model_provider)
    except Exception as exc:
        logging.error(f"Error while retrieving API Key: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return api_key
@router.get("/verify_end_point", status_code=200)
async def verify_end_point(model_api_key: str = None, end_point: str = None, model_provider: str = None):
    """
    Validate a model endpoint through ModelsHelper.validate_end_point.

    NOTE(review): this route has no auth dependency and receives the API key as a query
    parameter (which may end up in access logs) — confirm this is intended.

    Raises:
        HTTPException (status_code=500): If validation raises.
    """
    try:
        verdict = ModelsHelper.validate_end_point(model_api_key, end_point, model_provider)
    except Exception as exc:
        logging.error(f"Error validating Endpoint: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return verdict
@router.post("/store_model", status_code=200)
async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)):
    """
    Register a model for the caller's organisation.

    A missing (None) context_length is stored as 0.

    Raises:
        HTTPException (status_code=500): If Models.store_model_details raises.
    """
    try:
        logger.info(request)
        # `request.dict()` always contains "context_length" because it is a declared field, so the
        # former membership test never reached its 0 fallback; test the value instead.
        context_length = request.context_length if request.context_length is not None else 0
        return Models.store_model_details(db.session, organisation.id, request.model_name, request.description,
                                          request.end_point, request.model_provider_id, request.token_limit,
                                          request.type, request.version, context_length)
    except Exception as e:
        logging.error(f"Error storing the Model Details: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
@router.get("/fetch_models", status_code=200)
async def fetch_models(organisation=Depends(get_user_organisation)):
    """
    List the models registered for the caller's organisation.

    Raises:
        HTTPException (status_code=500): If Models.fetch_models raises.
    """
    try:
        org_models = Models.fetch_models(db.session, organisation.id)
    except Exception as exc:
        logging.error(f"Error Fetching Models: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return org_models
@router.get("/fetch_model/{model_id}", status_code=200)
async def fetch_model_details(model_id: int, organisation=Depends(get_user_organisation)):
    """
    Return the details of one model belonging to the caller's organisation.

    Raises:
        HTTPException (status_code=500): If Models.fetch_model_details raises.
    """
    try:
        details = Models.fetch_model_details(db.session, organisation.id, model_id)
    except Exception as exc:
        logging.error(f"Error Fetching Model Details: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return details
@router.post("/fetch_model_data", status_code=200)
async def fetch_data(request: ModelName, organisation=Depends(get_user_organisation)):
    """
    Return call-log data for a model via CallLogHelper, scoped to the caller's organisation.

    Raises:
        HTTPException (status_code=500): If CallLogHelper.fetch_data raises.
    """
    try:
        return CallLogHelper(session=db.session, organisation_id=organisation.id).fetch_data(request.model)
    except Exception as e:
        # Distinct message so failures here are not confused with /fetch_model/{model_id}.
        logging.error(f"Error Fetching Model Data: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
@router.get("/get/list", status_code=200)
def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)):
    """
    Get the marketplace model list annotated with install details for the caller's organisation.

    Args:
        page (int, optional): Zero-based page number. Defaults to 0; negative values are clamped to 0.

    Returns:
        The marketplace models with install details.
    """
    page = max(page, 0)
    marketplace_page = Models.fetch_marketplace_list(page)
    return Models.get_model_install_details(db.session, marketplace_page, organisation.id)
@router.get("/marketplace/list/{page}", status_code=200)
def get_marketplace_models_list(page: int = 0):
    """
    List models owned by the marketplace organisation, each with its provider name.

    Args:
        page (int): Zero-based page number (page size 16). A negative page returns every model.

    Returns:
        list[dict]: Each model's attribute dict plus a "provider" key (None if the provider
        config row is missing).
    """
    organisation_id = get_config("MARKETPLACE_ORGANISATION_ID")
    if organisation_id is not None:
        organisation_id = int(organisation_id)
    page_size = 16
    query = db.session.query(Models).filter(Models.org_id == organisation_id)
    if page < 0:
        models = query.all()
    else:
        models = query.offset(page * page_size).limit(page_size).all()
    # Cache provider names so models sharing a provider trigger a single lookup.
    provider_names = {}
    models_list = []
    for model in models:
        provider_id = model.model_provider_id
        if provider_id not in provider_names:
            provider_config = db.session.query(ModelsConfig).filter(ModelsConfig.id == provider_id).first()
            # A missing provider row used to raise AttributeError (HTTP 500); report None instead.
            provider_names[provider_id] = provider_config.provider if provider_config is not None else None
        # NOTE(review): __dict__ is the live instance dict (it includes SQLAlchemy's internal
        # state and this assignment mutates the instance); kept as-is for response compatibility.
        model_dict = model.__dict__
        model_dict["provider"] = provider_names[provider_id]
        models_list.append(model_dict)
    return models_list
@router.get("/get/models_details", status_code=200)
def get_models_details(page: int = 0):
    """
    Get the marketplace model list with install details relative to the marketplace organisation.

    Args:
        page (int, optional): Zero-based page number. Defaults to 0; negative values are clamped to 0.

    Returns:
        The marketplace models with install details (MARKETPLACE model type).
    """
    raw_org_id = get_config("MARKETPLACE_ORGANISATION_ID")
    organisation_id = int(raw_org_id) if raw_org_id is not None else None
    page = max(page, 0)
    marketplace_page = Models.fetch_marketplace_list(page)
    return Models.get_model_install_details(db.session, marketplace_page, organisation_id,
                                            ModelsTypes.MARKETPLACE.value)
@router.get("/test_local_llm", status_code=200)
def test_local_llm():
    """
    Smoke-test the local LLM: load it with a 4096-token context, run one chat completion
    with its grammar, and log the reply.

    Returns:
        str: "Model loaded successfully." on success.

    Raises:
        HTTPException (status_code=404): If the model or grammar is missing ("Grammar not found."
            for the latter), or if loading/inference fails for any other reason.
    """
    try:
        llm_loader = LLMLoader(context_length=4096)
        llm_model = llm_loader.model
        llm_grammar = llm_loader.grammar
        if llm_model is None:
            logger.error("Model not found.")
            raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
        if llm_grammar is None:
            logger.error("Grammar not found.")
            raise HTTPException(status_code=404, detail="Grammar not found.")
        messages = [
            {"role": "system",
             "content": "You are an AI assistant. Give response in a proper JSON format"},
            {"role": "user",
             "content": "Hi!"}
        ]
        response = llm_model.create_chat_completion(messages=messages, grammar=llm_grammar)
        content = response["choices"][0]["message"]["content"]
        logger.info(content)
        return "Model loaded successfully."
    except HTTPException:
        # Re-raise the specific errors above instead of collapsing them into the generic one.
        raise
    except Exception as e:
        # Interpolate the error: with stdlib-style loggers an extra positional argument and no
        # %s placeholder yields a formatting error instead of the message.
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.") from e
================================================
FILE: superagi/controllers/organisation.py
================================================
from datetime import datetime
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from superagi.helper.auth import get_user_organisation
from superagi.helper.auth import check_auth
from superagi.helper.encyption_helper import decrypt_data
from superagi.helper.tool_helper import register_toolkits
from superagi.llms.google_palm import GooglePalm
from superagi.llms.llm_model_factory import build_model_with_api_key
from superagi.llms.openai import OpenAi
from superagi.models.configuration import Configuration
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.user import User
from superagi.lib.logger import logger
from superagi.models.workflows.agent_workflow import AgentWorkflow
# from superagi.types.db import OrganisationIn, OrganisationOut
# FastAPI router on which the organisation endpoints below are registered.
router = APIRouter()
class OrganisationOut(BaseModel):
    """Response schema for organisation endpoints."""
    id: int
    name: str
    description: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow the schema to be populated directly from ORM objects (returned by the routes).
        orm_mode = True
class OrganisationIn(BaseModel):
    """Request body for creating or updating an organisation."""
    name: str
    description: str

    class Config:
        # Allow the schema to be populated from ORM objects.
        orm_mode = True
# CRUD Operations
@router.post("/add", response_model=OrganisationOut, status_code=201)
def create_organisation(organisation: OrganisationIn,
                        Authorize: AuthJWT = Depends(check_auth)):
    """
    Create an organisation and register the default toolkits for it.

    Args:
        organisation: Name and description of the new organisation.

    Returns:
        Organisation: The persisted organisation.
    """
    created_org = Organisation(name=organisation.name, description=organisation.description)
    session = db.session
    session.add(created_org)
    session.commit()
    session.flush()
    register_toolkits(session=session, organisation=created_org)
    logger.info(created_org)
    return created_org
@router.get("/get/{organisation_id}", response_model=OrganisationOut)
def get_organisation(organisation_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Fetch a single organisation by id.

    Raises:
        HTTPException (status_code=404): If no organisation has that id.
    """
    found_org = db.session.query(Organisation).filter(Organisation.id == organisation_id).first()
    if found_org:
        return found_org
    raise HTTPException(status_code=404, detail="organisation not found")
@router.put("/update/{organisation_id}", response_model=OrganisationOut)
def update_organisation(organisation_id: int, organisation: OrganisationIn,
                        Authorize: AuthJWT = Depends(check_auth)):
    """
    Overwrite an organisation's name and description.

    Raises:
        HTTPException (status_code=404): If no organisation has that id.
    """
    target_org = db.session.query(Organisation).filter(Organisation.id == organisation_id).first()
    if not target_org:
        raise HTTPException(status_code=404, detail="Organisation not found")
    for field_name in ("name", "description"):
        setattr(target_org, field_name, getattr(organisation, field_name))
    db.session.commit()
    return target_org
@router.get("/get/user/{user_id}", response_model=OrganisationOut, status_code=201)
def get_organisations_by_user(user_id: int):
    """
    Return the user's organisation, creating it (and its default project) if needed.

    Raises:
        HTTPException (status_code=400): If no user has that id.
    """
    user = db.session.query(User).filter(User.id == user_id).first()
    if user is None:
        raise HTTPException(status_code=400, detail="User not found")
    user_org = Organisation.find_or_create_organisation(db.session, user)
    Project.find_or_create_default_project(db.session, user_org.id)
    return user_org
@router.get("/llm_models")
def get_llm_models(organisation=Depends(get_user_organisation)):
    """
    List the LLM models available with the organisation's configured model source and API key.

    Returns:
        list: Models from the provider, or [] if no model could be built.

    Raises:
        HTTPException (status_code=400): If the model_api_key or model_source configuration is missing.
    """
    def _org_config(config_key):
        return db.session.query(Configuration).filter(Configuration.organisation_id == organisation.id,
                                                      Configuration.key == config_key).first()

    api_key_row = _org_config("model_api_key")
    source_row = _org_config("model_source")
    if api_key_row is None or source_row is None:
        raise HTTPException(status_code=400, detail="Organisation not found")
    llm = build_model_with_api_key(source_row.value, decrypt_data(api_key_row.value))
    if llm is None:
        return []
    return llm.get_models()
@router.get("/agent_workflows")
def agent_workflows(organisation=Depends(get_user_organisation)):
    """
    Return the names of all agent workflows (not filtered by organisation).
    """
    return [workflow_row.name for workflow_row in db.session.query(AgentWorkflow).all()]
================================================
FILE: superagi/controllers/project.py
================================================
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from pydantic import BaseModel
from superagi.models.project import Project
from superagi.models.organisation import Organisation
from fastapi import APIRouter
from superagi.helper.auth import check_auth
from superagi.lib.logger import logger
# from superagi.types.db import ProjectIn, ProjectOut
# FastAPI router on which the project endpoints below are registered.
router = APIRouter()
class ProjectOut(BaseModel):
    """Response schema for project endpoints."""
    id: int
    name: str
    organisation_id: int
    description: str

    class Config:
        # Allow the schema to be populated directly from ORM objects (returned by the routes).
        orm_mode = True
class ProjectIn(BaseModel):
    """Request body for creating or updating a project."""
    name: str
    organisation_id: int
    description: str

    class Config:
        # Allow the schema to be populated from ORM objects.
        orm_mode = True
# CRUD Operations
@router.post("/add", response_model=ProjectOut, status_code=201)
def create_project(project: ProjectIn,
                   Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new project.

    Args:
        project (ProjectIn): Project data.

    Returns:
        Project: The created project.

    Raises:
        HTTPException (status_code=404): If the organisation with the specified ID is not found.
    """
    # Interpolate the id: passing it as an extra positional argument with no %s placeholder makes
    # stdlib-style loggers report a formatting error instead of the value.
    logger.info(f"Organisation_id : {project.organisation_id}")
    organisation = db.session.query(Organisation).get(project.organisation_id)
    if not organisation:
        raise HTTPException(status_code=404, detail="Organisation not found")
    # Use a distinct name rather than rebinding the `project` request parameter.
    db_project = Project(
        name=project.name,
        organisation_id=organisation.id,
        description=project.description
    )
    db.session.add(db_project)
    db.session.commit()
    return db_project
@router.get("/get/{project_id}", response_model=ProjectOut)
def get_project(project_id: int, Authorize: AuthJWT = Depends(check_auth)):
    """
    Fetch a single project by id.

    Raises:
        HTTPException (status_code=404): If no project has that id.
    """
    found_project = db.session.query(Project).filter(Project.id == project_id).first()
    if found_project:
        return found_project
    raise HTTPException(status_code=404, detail="project not found")
@router.put("/update/{project_id}", response_model=ProjectOut)
def update_project(project_id: int, project: ProjectIn,
                   Authorize: AuthJWT = Depends(check_auth)):
    """
    Update a project's name and description, and move it to another organisation when a
    truthy organisation_id is supplied.

    Raises:
        HTTPException (status_code=404): If the project, or the requested organisation, is not found.
    """
    existing = db.session.query(Project).get(project_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Project not found")
    if project.organisation_id:
        parent_org = db.session.query(Organisation).get(project.organisation_id)
        if not parent_org:
            raise HTTPException(status_code=404, detail="Organisation not found")
        existing.organisation_id = parent_org.id
    existing.name, existing.description = project.name, project.description
    db.session.add(existing)
    db.session.commit()
    return existing
@router.get("/get/organisation/{organisation_id}")
def get_projects_organisation(organisation_id: int,
                              Authorize: AuthJWT = Depends(check_auth)):
    """
    List an organisation's projects, ensuring a default project exists first.

    Returns:
        list[Project]: The organisation's projects; if the query still comes back empty, the
        default project is appended.
    """
    Project.find_or_create_default_project(db.session, organisation_id)
    org_projects = db.session.query(Project).filter(Project.organisation_id == organisation_id).all()
    if not org_projects:
        org_projects.append(Project.find_or_create_default_project(db.session, organisation_id))
    return org_projects
================================================
FILE: superagi/controllers/resources.py
================================================
import datetime
import os
from pathlib import Path
import boto3
from botocore.exceptions import NoCredentialsError
from fastapi import APIRouter
from fastapi import File, Form, UploadFile
from fastapi import HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from superagi.config.config import get_config
from superagi.helper.auth import check_auth
from superagi.helper.resource_helper import ResourceHelper
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.resource import Resource
from superagi.worker import summarize_resource
from superagi.types.storage_types import StorageType
# FastAPI router on which the resource endpoints below are registered.
router = APIRouter()
# Module-level S3 client used by the upload and download endpoints when STORAGE_TYPE is S3.
# It is created at import time from the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY config keys.
s3 = boto3.client(
    's3',
    aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
)
@router.post("/add/{agent_id}", status_code=201)
async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), size=Form(...), type=Form(...),
                 Authorize: AuthJWT = Depends(check_auth)):
    """
    Upload a file as a resource for an agent and queue it for summarisation.

    Args:
        agent_id (int): ID of the agent.
        file (UploadFile): Uploaded file.
        name (str): Name of the resource (its extension decides whether the type is accepted).
        size (str): Size of the resource.
        type (str): Type of the resource.

    Returns:
        Resource: Uploaded resource.

    Raises:
        HTTPException (status_code=400): If the agent does not exist, the file type is not
            supported, or the uploaded filename has no usable final component.
        HTTPException (status_code=500): If AWS credentials are not found for an S3 upload.
    """
    agent = db.session.query(Agent).filter(Agent.id == agent_id).first()
    if agent is None:
        raise HTTPException(status_code=400, detail="Agent does not exists")

    # str.endswith() accepts a tuple, which tests every suffix in one call.
    accepted_file_types = (".pdf", ".docx", ".pptx", ".csv", ".txt", ".epub")
    if not name.endswith(accepted_file_types):
        raise HTTPException(status_code=400, detail="File type not supported!")

    # The uploaded filename is client-controlled: keep only its final path component so values
    # such as "../../x" or absolute paths cannot place the file outside save_directory.
    safe_file_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_file_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name")

    storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value))
    save_directory = ResourceHelper.get_root_input_dir()
    if "{agent_id}" in save_directory:
        save_directory = ResourceHelper.get_formatted_agent_level_path(
            agent=Agent.get_agent_from_id(session=db.session, agent_id=agent_id),
            path=save_directory)
    file_path = os.path.join(save_directory, safe_file_name)
    if storage_type == StorageType.FILE:
        os.makedirs(save_directory, exist_ok=True)
        with open(file_path, "wb") as f:
            contents = await file.read()
            f.write(contents)
            file.file.close()
    elif storage_type == StorageType.S3:
        bucket_name = get_config("BUCKET_NAME")
        file_path = 'resources' + file_path
        try:
            s3.upload_fileobj(file.file, bucket_name, file_path)
            logger.info("File uploaded successfully!")
        except NoCredentialsError:
            raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.")

    resource = Resource(name=name, path=file_path, storage_type=storage_type.value, size=size, type=type, channel="INPUT",
                        agent_id=agent.id)
    db.session.add(resource)
    db.session.commit()
    db.session.flush()
    summarize_resource.delay(agent_id, resource.id)
    logger.info(resource)
    return resource
@router.get("/get/all/{agent_id}", status_code=200)
def get_all_resources(agent_id: int,
                      Authorize: AuthJWT = Depends(check_auth)):
    """
    List every resource attached to an agent.

    Args:
        agent_id (int): ID of the agent.

    Returns:
        list[Resource]: The agent's resources (possibly empty).
    """
    agent_resources = db.session.query(Resource).filter(Resource.agent_id == agent_id)
    return agent_resources.all()
@router.get("/get/{resource_id}", status_code=200)
def download_file_by_id(resource_id: int,
                        Authorize: AuthJWT = Depends(check_auth)):
    """
    Download a particular resource by resource_id.

    Args:
        resource_id (int): ID of the resource.
        Authorize (AuthJWT, optional): Authorization dependency.

    Returns:
        StreamingResponse: Streaming response for downloading the resource.

    Raises:
        HTTPException (status_code=400): If the resource or its owning agent is not found.
        HTTPException (status_code=403): If the agent's organisation_id does not match the JWT subject.
        HTTPException (status_code=404): If a locally stored file is missing.
    """
    from urllib.parse import quote

    # NOTE(review): the raw JWT subject is compared with the agent's organisation_id, whereas
    # other controllers resolve the caller through get_current_user(Authorize). Confirm the
    # token subject really is an organisation id; otherwise every request is rejected with 403.
    current_user_org_id = Authorize.get_jwt_subject()

    resource = db.session.query(Resource).filter(Resource.id == resource_id).first()
    if not resource:
        raise HTTPException(status_code=400, detail="Resource Not found!")

    agent = db.session.query(Agent).filter(Agent.id == resource.agent_id).first()
    if not agent:
        raise HTTPException(status_code=400, detail="Associated agent not found!")

    if str(agent.organisation_id) != str(current_user_org_id):
        raise HTTPException(status_code=403, detail="You don't have permission to access this resource")

    download_file_path = resource.path
    file_name = resource.name
    if resource.storage_type == StorageType.S3.value:
        bucket_name = get_config("BUCKET_NAME")
        file_key = resource.path
        response = s3.get_object(Bucket=bucket_name, Key=file_key)
        content = response["Body"]
    else:
        abs_file_path = Path(download_file_path).resolve()
        if not abs_file_path.is_file():
            raise HTTPException(status_code=404, detail="File not found")
        # Open eagerly so open() errors surface before the response starts, and close the handle
        # when streaming finishes instead of leaving it to the garbage collector.
        file_handle = open(str(abs_file_path), "rb")

        def _stream_and_close(handle, chunk_size=64 * 1024):
            with handle:
                for chunk in iter(lambda: handle.read(chunk_size), b""):
                    yield chunk

        content = _stream_and_close(file_handle)

    # HTTP header values must be latin-1 encodable and a quote or newline in the name would
    # corrupt the header, so send a sanitised ASCII fallback plus an RFC 6266 `filename*`
    # parameter carrying the percent-encoded UTF-8 name.
    display_name = str(file_name)
    ascii_fallback = "".join(
        ch for ch in display_name if 32 <= ord(ch) < 127 and ch not in '"\\'
    ) or "download"
    encoded_name = quote(display_name, safe="")
    return StreamingResponse(
        content,
        media_type="application/octet-stream",
        headers={
            "Content-Disposition": f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded_name}"
        }
    )
================================================
FILE: superagi/controllers/tool.py
================================================
from datetime import datetime
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit
router = APIRouter()
class ToolOut(BaseModel):
    """Response schema for a single tool record."""
    id: int
    name: str
    folder_name: str
    class_name: str
    file_name: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
class ToolIn(BaseModel):
    """Request body used to create or update a tool record."""
    name: str
    folder_name: str
    class_name: str
    file_name: str

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
# CRUD Operations
@router.post("/add", response_model=ToolOut, status_code=201)
def create_tool(
        tool: ToolIn,
        Authorize: AuthJWT = Depends(check_auth),
):
    """
    Persist a new tool record.

    Args:
        tool (ToolIn): Name, folder, class name and file name of the tool.

    Returns:
        Tool: The freshly committed tool row (serialised through ToolOut).
    """
    new_tool = Tool(
        name=tool.name,
        folder_name=tool.folder_name,
        class_name=tool.class_name,
        file_name=tool.file_name,
    )
    session = db.session
    session.add(new_tool)
    session.commit()
    return new_tool
@router.get("/get/{tool_id}", response_model=ToolOut)
def get_tool(
        tool_id: int,
        Authorize: AuthJWT = Depends(check_auth),
):
    """
    Fetch one tool by its primary key.

    Args:
        tool_id (int): ID of the tool.

    Returns:
        Tool: The matching tool row.

    Raises:
        HTTPException (status_code=404): When no tool has this ID.
    """
    found = db.session.query(Tool).filter(Tool.id == tool_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="Tool not found")
    return found
@router.get("/list")
def get_tools(
        organisation: Organisation = Depends(get_user_organisation)):
    """
    List every tool belonging to the caller's organisation.

    Tools are returned grouped by toolkit, in the order the toolkits are
    returned by the database.
    """
    org_toolkits = db.session.query(Toolkit).filter(Toolkit.organisation_id == organisation.id).all()
    return [
        tool
        for kit in org_toolkits
        for tool in db.session.query(Tool).filter(Tool.toolkit_id == kit.id).all()
    ]
@router.put("/update/{tool_id}", response_model=ToolOut)
def update_tool(
        tool_id: int,
        tool: ToolIn,
        Authorize: AuthJWT = Depends(check_auth),
):
    """
    Overwrite the editable fields of an existing tool.

    Args:
        tool_id (int): ID of the tool to modify.
        tool (ToolIn): New name, folder, class name and file name.

    Returns:
        Tool: The updated tool row.

    Raises:
        HTTPException (status_code=404): When no tool has this ID.
    """
    existing = db.session.query(Tool).filter(Tool.id == tool_id).first()
    if not existing:
        raise HTTPException(status_code=404, detail="Tool not found")
    # Copy every editable field from the request onto the stored row.
    for field in ("name", "folder_name", "class_name", "file_name"):
        setattr(existing, field, getattr(tool, field))
    db.session.add(existing)
    db.session.commit()
    return existing
================================================
FILE: superagi/controllers/tool_config.py
================================================
from fastapi import APIRouter, HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
from superagi.helper.auth import check_auth
from superagi.helper.auth import get_user_organisation
from superagi.models.organisation import Organisation
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from superagi.helper.encyption_helper import encrypt_data
from superagi.helper.encyption_helper import decrypt_data, is_encrypted
from superagi.types.key_type import ToolConfigKeyType
import json
router = APIRouter()
class ToolConfigOut(BaseModel):
    """
    Response schema for a tool configuration entry.

    Fields must be declared with annotations. The previous ``id = int`` form
    bound the *type object* as a plain class attribute, and pydantic does not
    turn such attributes into fields, so the model had no fields at all.
    Each field defaults to None (implicitly Optional), so objects that lack
    some attributes still serialise instead of failing validation.
    """
    id: int = None
    key: str = None
    value: str = None
    toolkit_id: int = None

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
@router.post("/add/{toolkit_name}", status_code=201)
def update_tool_config(toolkit_name: str, configs: list, organisation: Organisation = Depends(get_user_organisation)):
    """
    Update tool configurations for a specific tool kit.

    Only keys that already exist for the toolkit are updated; unknown keys and
    entries with a missing key or value are skipped. Values of FILE-typed keys
    are JSON-encoded before encryption. All updates are committed together;
    on an unexpected error the session is rolled back.

    Args:
        toolkit_name (str): The name of the tool kit.
        configs (list): A list of dictionaries containing the tool configurations.
            Each dictionary should have the following keys:
            - "key" (str): The key of the configuration.
            - "value" (str): The new value for the configuration.
        organisation (Organisation): The caller's organisation (scopes the toolkit lookup).

    Returns:
        dict: A dictionary with the message "Tool configs updated successfully".

    Raises:
        HTTPException (status_code=404): If the specified tool kit is not found.
        HTTPException (status_code=500): If an unexpected error occurs during the update process.
    """
    try:
        # Check if the tool kit exists
        toolkit = Toolkit.get_toolkit_from_name(db.session, toolkit_name, organisation)
        if toolkit is None:
            raise HTTPException(status_code=404, detail="Tool kit not found")

        for config in configs:
            key = config.get("key")
            value = config.get("value")
            if key is None or value is None:
                continue
            tool_config = db.session.query(ToolConfig).filter_by(toolkit_id=toolkit.id, key=key).first()
            if not tool_config:
                continue
            if tool_config.key_type == ToolConfigKeyType.FILE.value:
                value = json.dumps(value)
            # Values are stored encrypted.
            tool_config.value = encrypt_data(value)

        # One commit for the whole request so it is applied all-or-nothing.
        db.session.commit()
        return {"message": "Tool configs updated successfully"}
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) must reach the client
        # unchanged instead of being re-wrapped as a 500 below.
        raise
    except Exception as e:
        db.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-or-update/{toolkit_name}", status_code=201, response_model=ToolConfigOut)
def create_or_update_tool_config(toolkit_name: str, tool_configs,
                                 Authorize: AuthJWT = Depends(check_auth)):
    """
    Create or update tool configurations for a specific tool kit.

    For each entry, an existing row with the same key is updated in place; a
    new row is created only when no row with that key exists. Values of
    FILE-typed keys are JSON-encoded before encryption, matching
    update_tool_config.

    Args:
        toolkit_name (str): The name of the tool kit.
        tool_configs (list): A list of tool configuration objects exposing
            ``.key`` and ``.value`` attributes.

    Returns:
        Toolkit: The refreshed tool kit object.

    Raises:
        HTTPException (status_code=404): If the specified tool kit is not found.
    """
    # NOTE(review): `tool_configs` has no annotation, so FastAPI treats it as a
    # query parameter rather than a request body; the toolkit lookup is also
    # not scoped to the caller's organisation. Confirm the intended contract.
    toolkit = db.session.query(Toolkit).filter_by(name=toolkit_name).first()
    if not toolkit:
        raise HTTPException(status_code=404, detail='ToolKit not found')

    for tool_config in tool_configs:
        existing_tool_config = db.session.query(ToolConfig).filter(
            ToolConfig.toolkit_id == toolkit.id,
            ToolConfig.key == tool_config.key
        ).first()
        # Branch on the row's existence, not its value: checking `.value`
        # crashed when no row existed and duplicated rows whose value was empty.
        if existing_tool_config:
            value = tool_config.value
            if existing_tool_config.key_type == ToolConfigKeyType.FILE.value:
                value = json.dumps(value)
            existing_tool_config.value = encrypt_data(value)
        else:
            new_tool_config = ToolConfig(key=tool_config.key, value=encrypt_data(tool_config.value),
                                         toolkit_id=toolkit.id)
            db.session.add(new_tool_config)

    db.session.commit()
    db.session.refresh(toolkit)
    return toolkit
@router.get("/get/toolkit/{toolkit_name}", status_code=200)
def get_all_tool_configs(toolkit_name: str, organisation: Organisation = Depends(get_user_organisation)):
    """
    Return every configuration entry of a toolkit owned by the caller's organisation.

    Non-empty values are decrypted when they are encrypted, and values of
    FILE-typed keys are JSON-decoded before being returned.

    Args:
        toolkit_name (str): The name of the tool kit.
        organisation (Organisation): The organization associated with the user.

    Returns:
        list: The tool configuration rows of the toolkit.

    Raises:
        HTTPException (status_code=404): When the organisation has no toolkit with this name.
    """
    toolkit = (
        db.session.query(Toolkit)
        .filter(Toolkit.name == toolkit_name, Toolkit.organisation_id == organisation.id)
        .first()
    )
    if not toolkit:
        raise HTTPException(status_code=404, detail='ToolKit not found')

    tool_configs = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit.id).all()
    for cfg in tool_configs:
        # Empty values are returned untouched.
        if not cfg.value:
            continue
        if is_encrypted(cfg.value):
            cfg.value = decrypt_data(cfg.value)
        if cfg.key_type == ToolConfigKeyType.FILE.value:
            cfg.value = json.loads(cfg.value)
    return tool_configs
@router.get("/get/toolkit/{toolkit_name}/key/{key}", status_code=200)
def get_tool_config(toolkit_name: str, key: str, organisation: Organisation = Depends(get_user_organisation)):
    """
    Get a specific tool configuration by tool kit name and key.

    The value is decrypted when encrypted, and JSON-decoded for FILE-typed keys.

    Args:
        toolkit_name (str): The name of the tool kit.
        key (str): The key of the tool configuration.
        organisation (Organisation): The organization associated with the user.

    Returns:
        ToolConfig: The tool configuration with the specified key.

    Raises:
        HTTPException (status_code=403): If the caller's organisation has no tool kit with this name.
        HTTPException (status_code=404): If the tool configuration is not found.
    """
    # Resolve the toolkit inside the caller's organisation. The previous code
    # compared an un-executed Query object against a list of Toolkit rows, so
    # the membership test never matched (always 403), and `.id` on a Query is invalid.
    toolkit = db.session.query(Toolkit).filter(Toolkit.name == toolkit_name,
                                               Toolkit.organisation_id == organisation.id).first()
    if toolkit is None:
        raise HTTPException(status_code=403, detail='Unauthorized')

    tool_config = db.session.query(ToolConfig).filter(
        ToolConfig.toolkit_id == toolkit.id,
        ToolConfig.key == key
    ).first()
    if not tool_config:
        raise HTTPException(status_code=404, detail="Tool configuration not found")

    # Guard empty values, consistent with get_all_tool_configs (json.loads(None) would fail).
    if tool_config.value:
        if is_encrypted(tool_config.value):
            tool_config.value = decrypt_data(tool_config.value)
        if tool_config.key_type == ToolConfigKeyType.FILE.value:
            tool_config.value = json.loads(tool_config.value)
    return tool_config
================================================
FILE: superagi/controllers/toolkit.py
================================================
from typing import Optional
import requests
from fastapi import APIRouter, Body
from fastapi import HTTPException, Depends, Query
from fastapi_sqlalchemy import db
from superagi.config.config import get_config
from superagi.helper.auth import get_user_organisation
from superagi.helper.tool_helper import get_readme_content_from_code_link, download_tool, process_files, \
add_tool_to_json
from superagi.helper.github_helper import GithubHelper
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from superagi.types.common import GitHubLinkRequest
from superagi.helper.tool_helper import compare_toolkit
from superagi.helper.encyption_helper import decrypt_data, is_encrypted
router = APIRouter()
# marketplace_url = "https://app.superagi.com/api"
# marketplace_url = "http://localhost:8001/"
# For internal use
@router.get("/marketplace/list/{page}")
def get_marketplace_toolkits(
        page: int = 0,
):
    """
    Return one page (30 entries) of toolkits owned by the marketplace organisation.

    Each toolkit gets its tools attached, and its ``updated_at`` is replaced
    by an upper-cased ``DD-MON-YYYY`` string.

    Args:
        page (int): Zero-based page index.

    Returns:
        list: The toolkits on the requested page.
    """
    marketplace_org_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    per_page = 30
    toolkits = (
        db.session.query(Toolkit)
        .filter(Toolkit.organisation_id == marketplace_org_id)
        .offset(page * per_page)
        .limit(per_page)
        .all()
    )
    for kit in toolkits:
        kit.tools = db.session.query(Tool).filter(Tool.toolkit_id == kit.id).all()
        kit.updated_at = kit.updated_at.strftime('%d-%b-%Y').upper()
    return toolkits
# For internal use
@router.get("/marketplace/details/{toolkit_name}")
def get_marketplace_toolkit_detail(toolkit_name: str):
    """
    Get tool kit details from the marketplace.

    The toolkit is returned with its tools and configs attached; encrypted
    config values are decrypted.

    Args:
        toolkit_name (str): The name of the tool kit.

    Returns:
        Toolkit: The tool kit details from the marketplace.

    Raises:
        HTTPException (status_code=404): If the marketplace has no tool kit with this name.
    """
    organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    toolkit = db.session.query(Toolkit).filter(Toolkit.organisation_id == organisation_id,
                                               Toolkit.name == toolkit_name).first()
    # Without this guard a missing toolkit produced an AttributeError (500);
    # respond with 404 like the other marketplace endpoints.
    if not toolkit:
        raise HTTPException(status_code=404, detail='ToolKit not found')
    toolkit.tools = db.session.query(Tool).filter(Tool.toolkit_id == toolkit.id).all()
    toolkit.configs = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit.id).all()
    for tool_configs in toolkit.configs:
        if is_encrypted(tool_configs.value):
            tool_configs.value = decrypt_data(tool_configs.value)
    return toolkit
# For internal use
@router.get("/marketplace/readme/{toolkit_name}")
def get_marketplace_toolkit_readme(toolkit_name: str):
    """
    Fetch the README of a marketplace toolkit from its code link.

    Args:
        toolkit_name (str): The name of the tool kit.

    Returns:
        str: The README content.

    Raises:
        HTTPException (status_code=404): When the marketplace has no toolkit with this name.
    """
    marketplace_org_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    toolkit = (
        db.session.query(Toolkit)
        .filter(Toolkit.name == toolkit_name, Toolkit.organisation_id == marketplace_org_id)
        .first()
    )
    if not toolkit:
        raise HTTPException(status_code=404, detail='ToolKit not found')
    return get_readme_content_from_code_link(toolkit.tool_code_link)
# For internal use
@router.get("/marketplace/tools/{toolkit_name}")
def get_marketplace_toolkit_tools(toolkit_name: str):
    """
    Return a tool of a marketplace toolkit.

    NOTE(review): despite the endpoint name, only the first matching tool is
    returned (``.first()``), not the full list — confirm whether callers rely
    on this before changing it.

    Args:
        toolkit_name (str): The name of the tool kit.

    Returns:
        Tool: The first tool associated with the toolkit, or None.

    Raises:
        HTTPException (status_code=404): When the marketplace has no toolkit with this name.
    """
    marketplace_org_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    toolkit = (
        db.session.query(Toolkit)
        .filter(Toolkit.name == toolkit_name, Toolkit.organisation_id == marketplace_org_id)
        .first()
    )
    if not toolkit:
        raise HTTPException(status_code=404, detail="ToolKit not found")
    return db.session.query(Tool).filter(Tool.toolkit_id == toolkit.id).first()
@router.get("/get/install/{toolkit_name}")
def install_toolkit_from_marketplace(toolkit_name: str,
                                     organisation: Organisation = Depends(get_user_organisation)):
    """
    Download and install a tool kit from the marketplace.

    The toolkit, its tools and its configs are added or updated for the
    caller's organisation.

    Args:
        toolkit_name (str): The name of the tool kit.
        organisation (Organisation): The user's organisation.

    Returns:
        dict: A message indicating the successful installation of the tool kit.

    Raises:
        HTTPException (status_code=404): If the marketplace lookup returns nothing.
    """
    toolkit = Toolkit.fetch_marketplace_detail(search_str="details",
                                               toolkit_name=toolkit_name)
    # The marketplace lookup can return None (check_toolkit_update guards the
    # same call); fail with 404 instead of a TypeError on subscripting None.
    if toolkit is None:
        raise HTTPException(status_code=404, detail="Toolkit not found in marketplace")
    db_toolkit = Toolkit.add_or_update(session=db.session, name=toolkit['name'], description=toolkit['description'],
                                       tool_code_link=toolkit['tool_code_link'], organisation_id=organisation.id,
                                       show_toolkit=toolkit['show_toolkit'])
    for tool in toolkit['tools']:
        Tool.add_or_update(session=db.session, tool_name=tool['name'], description=tool['description'],
                           folder_name=tool['folder_name'], class_name=tool['class_name'], file_name=tool['file_name'],
                           toolkit_id=db_toolkit.id)
    for config in toolkit['configs']:
        ToolConfig.add_or_update(session=db.session, toolkit_id=db_toolkit.id, key=config['key'],
                                 value=config['value'], key_type=config['key_type'],
                                 is_secret=config['is_secret'], is_required=config['is_required'])
    return {"message": "ToolKit installed successfully"}
@router.get("/get/toolkit_name/{toolkit_name}")
def get_installed_toolkit_details(toolkit_name: str,
                                  organisation: Organisation = Depends(get_user_organisation)):
    """
    Get details of a locally installed tool kit by its name, including the details of its tools.

    Args:
        toolkit_name (str): The name of the tool kit.
        organisation (Organisation): The user's organisation.

    Returns:
        Toolkit: The tool kit object with its associated tools.

    Raises:
        HTTPException (status_code=404): If the organisation has no tool kit with this name.
    """
    # Filter on the toolkit's own organisation column. The previous
    # `Organisation.id == organisation.id` only constrained an implicitly joined
    # Organisation table, so a same-named toolkit of any organisation matched.
    toolkit = db.session.query(Toolkit).filter(Toolkit.name == toolkit_name,
                                               Toolkit.organisation_id == organisation.id).first()
    if not toolkit:
        raise HTTPException(status_code=404, detail='ToolKit not found')
    toolkit.tools = db.session.query(Tool).filter(Tool.toolkit_id == toolkit.id).all()
    return toolkit
@router.post("/get/local/install", status_code=200)
def download_and_install_tool(github_link_request: GitHubLinkRequest = Body(...),
                              organisation: Organisation = Depends(get_user_organisation)):
    """
    Register a tool from a GitHub link for local installation.

    The link is validated and then recorded via add_tool_to_json; the direct
    download/processing path (download_tool / process_files) is currently
    disabled.

    Args:
        github_link_request (GitHubLinkRequest): Request carrying the GitHub link.
        organisation (Organisation): The user's organisation.

    Returns:
        None

    Raises:
        HTTPException (status_code=400): If the GitHub link fails validation.
    """
    link = github_link_request.github_link
    if not GithubHelper.validate_github_link(link):
        raise HTTPException(status_code=400, detail="Invalid Github link")
    add_tool_to_json(link)
@router.get("/get/readme/{toolkit_name}")
def get_installed_toolkit_readme(toolkit_name: str, organisation: Organisation = Depends(get_user_organisation)):
    """
    Get the readme content of a toolkit installed in the caller's organisation.

    Args:
        toolkit_name (str): The name of the toolkit.
        organisation (Organisation): The user's organisation.

    Returns:
        str: The readme content of the toolkit.

    Raises:
        HTTPException (status_code=404): If the organisation has no toolkit with this name.
    """
    # Scope by Toolkit.organisation_id; filtering on Organisation.id only
    # constrained an implicitly joined table and matched any organisation's toolkit.
    toolkit = db.session.query(Toolkit).filter(Toolkit.name == toolkit_name,
                                               Toolkit.organisation_id == organisation.id).first()
    if not toolkit:
        raise HTTPException(status_code=404, detail='ToolKit not found')
    return get_readme_content_from_code_link(toolkit.tool_code_link)
# Following APIs will be used to get marketplace related information
@router.get("/get")
def handle_marketplace_operations(
        search_str: str = Query(None, title="Search String"),
        toolkit_name: str = Query(None, title="Tool Kit Name")
):
    """
    Proxy a marketplace detail lookup.

    Args:
        search_str (str, optional): Search/operation string forwarded to the marketplace.
        toolkit_name (str, optional): Name of the toolkit to look up.

    Returns:
        dict: Whatever Toolkit.fetch_marketplace_detail returns.
    """
    return Toolkit.fetch_marketplace_detail(search_str, toolkit_name)
@router.get("/get/list")
def handle_marketplace_operations_list(
        page: int = Query(None, title="Page Number"),
        organisation: Organisation = Depends(get_user_organisation)
):
    """
    List marketplace toolkits, annotated with install state for the caller's organisation.

    Args:
        page (int, optional): Page number forwarded to the marketplace listing.
        organisation (Organisation): The user's organisation.

    Returns:
        The marketplace list enriched by Toolkit.get_toolkit_installed_details.
    """
    listing = Toolkit.fetch_marketplace_list(page=page)
    return Toolkit.get_toolkit_installed_details(db.session, listing, organisation)
@router.get("/get/local/list")
def get_installed_toolkit_list(organisation: Organisation = Depends(get_user_organisation)):
    """
    Return all toolkits installed for an organisation, each with its tools attached.

    Args:
        organisation (Organisation): The organisation whose toolkits are listed.

    Returns:
        list: The installed toolkits.
    """
    installed = db.session.query(Toolkit).filter(Toolkit.organisation_id == organisation.id).all()
    for kit in installed:
        kit.tools = db.session.query(Tool).filter(Tool.toolkit_id == kit.id).all()
    return installed
@router.get("/check_update/{toolkit_name}")
def check_toolkit_update(toolkit_name: str, organisation: Organisation = Depends(get_user_organisation)):
    """
    Compare the installed copy of a toolkit with its marketplace version.

    Returns:
        True when the toolkit is not installed for the organisation; otherwise
        the result of compare_toolkit(marketplace, installed).

    Raises:
        HTTPException (status_code=404): When the marketplace does not know the toolkit.
    """
    marketplace_toolkit = Toolkit.fetch_marketplace_detail(search_str="details",
                                                           toolkit_name=toolkit_name)
    if marketplace_toolkit is None:
        raise HTTPException(status_code=404, detail="Toolkit not found in marketplace")

    installed = Toolkit.get_toolkit_from_name(db.session, toolkit_name, organisation)
    if installed is None:
        return True

    installed_dict = installed.to_dict()
    toolkit_id = installed_dict["id"]
    tools = Tool.get_toolkit_tools(db.session, toolkit_id)
    configs = ToolConfig.get_toolkit_tool_config(db.session, toolkit_id)
    installed_dict["configs"] = [cfg.to_dict() for cfg in configs]
    installed_dict["tools"] = [t.to_dict() for t in tools]
    return compare_toolkit(marketplace_toolkit, installed_dict)
@router.put("/update/{toolkit_name}")
def update_toolkit(toolkit_name: str, organisation: Organisation = Depends(get_user_organisation)):
    """
    Update the toolkit with the latest version from the marketplace.

    The toolkit row, its tools and its config keys are added or updated for
    the caller's organisation. Config keys are written without a value.

    Raises:
        HTTPException (status_code=404): If the marketplace lookup returns nothing.
    """
    marketplace_toolkit = Toolkit.fetch_marketplace_detail(search_str="details",
                                                           toolkit_name=toolkit_name)
    # Same guard as check_toolkit_update: the lookup can return None.
    if marketplace_toolkit is None:
        raise HTTPException(status_code=404, detail="Toolkit not found in marketplace")

    # Named db_toolkit so it no longer shadows this function's own name.
    db_toolkit = Toolkit.add_or_update(
        db.session,
        name=marketplace_toolkit["name"],
        description=marketplace_toolkit["description"],
        show_toolkit=len(marketplace_toolkit["tools"]) > 1,
        organisation_id=organisation.id,
        tool_code_link=marketplace_toolkit["tool_code_link"]
    )
    for tool in marketplace_toolkit["tools"]:
        Tool.add_or_update(db.session, tool_name=tool["name"], folder_name=tool["folder_name"],
                           class_name=tool["class_name"], file_name=tool["file_name"],
                           toolkit_id=db_toolkit.id, description=tool["description"])
    for tool_config_key in marketplace_toolkit["configs"]:
        ToolConfig.add_or_update(db.session, toolkit_id=db_toolkit.id, key=tool_config_key["key"],
                                 key_type=tool_config_key['key_type'], is_secret=tool_config_key['is_secret'],
                                 is_required=tool_config_key['is_required'])
================================================
FILE: superagi/controllers/twitter_oauth.py
================================================
import http.client as http_client
import json
from fastapi import APIRouter
from fastapi import Depends, Query
from fastapi.responses import RedirectResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
import superagi
from superagi.helper.auth import get_current_user, check_auth
from superagi.helper.twitter_tokens import TwitterTokens
from superagi.models.oauth_tokens import OauthTokens
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from superagi.helper.encyption_helper import decrypt_data, is_encrypted
router = APIRouter()
@router.get('/oauth-tokens')
async def twitter_oauth(oauth_token: str = Query(...),oauth_verifier: str = Query(...), Authorize: AuthJWT = Depends()):
    """
    Exchange a Twitter OAuth request token/verifier for access tokens and
    redirect the browser to the frontend with Twitter's response appended.

    Args:
        oauth_token (str): Request token returned by Twitter's callback.
        oauth_verifier (str): Verifier returned by Twitter's callback.

    Returns:
        RedirectResponse: Redirect to ``{FRONTEND_URL}/twitter_creds/?<twitter response>``.
    """
    from urllib.parse import urlencode

    # Percent-encode the caller-supplied values so they cannot inject extra
    # query parameters into the upstream request.
    query = urlencode({"oauth_verifier": oauth_verifier, "oauth_token": oauth_token})
    token_uri = f'https://api.twitter.com/oauth/access_token?{query}'
    # NOTE(review): http.client is blocking and runs inside an async handler,
    # so it stalls the event loop for the duration of the request.
    conn = http_client.HTTPSConnection("api.twitter.com")
    try:
        conn.request("POST", token_uri, "")
        res = conn.getresponse()
        response_data = res.read().decode('utf-8')
    finally:
        # Always release the socket, even if the request or read fails.
        conn.close()
    frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
    redirect_url_success = f"{frontend_url}/twitter_creds/?{response_data}"
    return RedirectResponse(url=redirect_url_success)
@router.post("/send_twitter_creds/{twitter_creds}")
def send_twitter_tool_configs(twitter_creds: str, Authorize: AuthJWT = Depends(check_auth)):
    """
    Store the current user's Twitter OAuth tokens together with the toolkit's API key/secret.

    Args:
        twitter_creds (str): JSON string with ``toolkit_id``, ``oauth_token`` and ``oauth_token_secret``.

    Returns:
        bool: True if OauthTokens.add_or_update returned a truthy result.
    """
    current_user = get_current_user(Authorize)
    user_id = current_user.id
    credentials = json.loads(twitter_creds)
    credentials["user_id"] = user_id
    toolkit_id = credentials["toolkit_id"]

    def _plain(value):
        # Decrypt into a local value only. Assigning the plaintext back onto
        # the tracked ToolConfig row could let a later autoflush/commit write
        # the decrypted secret to the database.
        return decrypt_data(value) if is_encrypted(value) else value

    toolkit = db.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
    api_key = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_KEY",
                                                  ToolConfig.toolkit_id == toolkit_id).first()
    api_key_secret = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_SECRET",
                                                         ToolConfig.toolkit_id == toolkit_id).first()
    final_creds = {
        "api_key": _plain(api_key.value),
        "api_key_secret": _plain(api_key_secret.value),
        "oauth_token": credentials["oauth_token"],
        "oauth_token_secret": credentials["oauth_token_secret"]
    }
    tokens = OauthTokens().add_or_update(db.session, toolkit_id, user_id, toolkit.organisation_id,
                                         "TWITTER_OAUTH_TOKENS", str(final_creds))
    return bool(tokens)
@router.get("/get_twitter_creds/toolkit_id/{toolkit_id}")
def get_twitter_tool_configs(toolkit_id: int):
    """
    Request a Twitter OAuth request token using the toolkit's stored API key and secret.

    NOTE(review): this endpoint has no authentication dependency — confirm
    that is intentional.

    Args:
        toolkit_id (int): Toolkit whose TWITTER_API_KEY / TWITTER_API_SECRET are used.

    Returns:
        Whatever TwitterTokens.get_request_token returns.
    """
    twitter_config_key = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,
                                                             ToolConfig.key == "TWITTER_API_KEY").first()
    twitter_config_secret = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,
                                                                ToolConfig.key == "TWITTER_API_SECRET").first()
    # Decrypt into locals instead of mutating the tracked ORM rows, so the
    # plaintext secrets cannot be flushed back to the database.
    api_key = twitter_config_key.value
    if is_encrypted(api_key):
        api_key = decrypt_data(api_key)
    api_secret = twitter_config_secret.value
    if is_encrypted(api_secret):
        api_secret = decrypt_data(api_secret)
    api_data = {
        "api_key": api_key,
        "api_secret": api_secret
    }
    response = TwitterTokens(db.session).get_request_token(api_data)
    return response
================================================
FILE: superagi/controllers/types/agent_execution_config.py
================================================
import datetime
from typing import List, Optional
from pydantic import BaseModel
from datetime import datetime
class AgentRunIn(BaseModel):
    """
    Input schema describing an agent run plus the agent configuration used for it.

    Fields typed ``Optional[...]`` accept None; the others carry no default.
    """
    status: Optional[str]
    name: Optional[str]
    agent_id: Optional[int]
    last_execution_time: Optional[datetime]
    num_of_calls: Optional[int]
    num_of_tokens: Optional[int]
    current_step_id: Optional[int]
    permission_id: Optional[int]
    goal: Optional[List[str]]
    instruction: Optional[List[str]]
    agent_workflow: str
    constraints: List[str]
    toolkits: List[int]
    tools: List[int]
    exit: str
    iteration_interval: int
    model: str
    permission_type: str
    # Presumably the long-term-memory vector DB name — confirm.
    LTM_DB: str
    max_iterations: int
    user_timezone: Optional[str]
    # Presumably a knowledge id — confirm against callers.
    knowledge: Optional[int]

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
================================================
FILE: superagi/controllers/types/agent_publish_config.py
================================================
from typing import List, Optional
from pydantic import BaseModel
class AgentPublish(BaseModel):
    """
    Input schema for publishing an agent configuration as a template.

    Fields typed ``Optional[...]`` accept None; the others carry no default.
    """
    name: str
    description: str
    agent_template_id: int
    goal: Optional[List[str]]
    instruction: Optional[List[str]]
    constraints: List[str]
    toolkits: List[int]
    tools: List[int]
    exit: str
    iteration_interval: int
    model: str
    permission_type: str
    # Presumably the long-term-memory vector DB name — confirm.
    LTM_DB: str
    max_iterations: int
    user_timezone: Optional[str]
    # Presumably a knowledge id — confirm against callers.
    knowledge: Optional[int]

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
================================================
FILE: superagi/controllers/types/agent_schedule.py
================================================
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
class AgentScheduleInput(BaseModel):
    """Input schema describing when (and how often) an agent run is scheduled."""
    agent_id: Optional[int]
    start_time: datetime
    # Defaults to None; presumably None means a one-off run — confirm.
    recurrence_interval: Optional[str] = None
    expiry_date: Optional[datetime] = None
    # Defaults to -1; presumably -1 means no run-count limit — confirm with the scheduler.
    expiry_runs: Optional[int] = -1
================================================
FILE: superagi/controllers/types/agent_with_config.py
================================================
from pydantic import BaseModel
from typing import List, Optional
from superagi.controllers.types.agent_schedule import AgentScheduleInput
class AgentConfigInput(BaseModel):
    """Input schema for creating an agent with its full configuration inside a project."""
    name: str
    project_id: int
    description: str
    goal: List[str]
    instruction: List[str]
    agent_workflow: str
    constraints: List[str]
    # Toolkit and tool ids (ints) selected for the agent.
    toolkits: List[int]
    tools: List[int]
    exit: str
    iteration_interval: int
    model: str
    permission_type: str
    # Presumably the long-term-memory vector DB name — confirm.
    LTM_DB: str
    max_iterations: int
    user_timezone: Optional[str]
    # Presumably a knowledge id — confirm against callers.
    knowledge: Optional[int]
class AgentConfigExtInput(BaseModel):
    """
    Agent-creation input variant whose ``tools`` are free-form dicts and which
    may carry an optional schedule.
    """
    name: str
    description: str
    project_id: Optional[int]
    goal: List[str]
    instruction: List[str]
    agent_workflow: str
    constraints: List[str]
    # Tool descriptors as dicts; their schema is not defined here — confirm with the consumer.
    tools: List[dict]
    LTM_DB:Optional[str]
    exit: Optional[str]
    permission_type: Optional[str]
    iteration_interval: int
    model: str
    schedule: Optional[AgentScheduleInput]
    max_iterations: int
    user_timezone: Optional[str]
    knowledge: Optional[int]
class AgentConfigUpdateExtInput(BaseModel):
    """
    Update counterpart of AgentConfigExtInput: every field is typed
    ``Optional[...]``, presumably so partial updates can be sent — confirm.
    """
    name: Optional[str]
    description: Optional[str]
    project_id: Optional[int]
    goal: Optional[List[str]]
    instruction: Optional[List[str]]
    agent_workflow: Optional[str]
    constraints: Optional[List[str]]
    # Tool descriptors as dicts; their schema is not defined here — confirm with the consumer.
    tools: Optional[List[dict]]
    LTM_DB:Optional[str]
    exit: Optional[str]
    permission_type: Optional[str]
    iteration_interval: Optional[int]
    model: Optional[str]
    schedule: Optional[AgentScheduleInput]
    max_iterations: Optional[int]
    user_timezone: Optional[str]
    knowledge: Optional[int]
================================================
FILE: superagi/controllers/types/agent_with_config_schedule.py
================================================
from pydantic import BaseModel
from superagi.controllers.types.agent_schedule import AgentScheduleInput
from superagi.controllers.types.agent_with_config import AgentConfigInput
class AgentConfigSchedule(BaseModel):
    """Input schema pairing an agent configuration with the schedule to run it on."""
    agent_config: AgentConfigInput
    schedule: AgentScheduleInput
================================================
FILE: superagi/controllers/types/models_types.py
================================================
from enum import Enum
class ModelsTypes(Enum):
    """Origin of a model entry: installed from the marketplace or custom-added."""
    MARKETPLACE = "Marketplace"
    CUSTOM = "Custom"

    @classmethod
    def get_models_types(cls, model_type):
        """
        Look up a member by name, case-insensitively.

        Args:
            model_type (str): Member name such as "marketplace" or "CUSTOM".

        Returns:
            ModelsTypes: The matching member.

        Raises:
            ValueError: If ``model_type`` is None or does not name a member.
        """
        # The messages previously referred to "queue status" and "storage name"
        # (copied from other enums); they now describe this enum.
        if model_type is None:
            raise ValueError("Model type cannot be None.")
        normalized = model_type.upper()
        if normalized in cls.__members__:
            return cls[normalized]
        raise ValueError(f"{model_type} is not a valid model type.")
================================================
FILE: superagi/controllers/user.py
================================================
from datetime import datetime
from typing import Optional
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Request
from fastapi_jwt_auth import AuthJWT
from pydantic import BaseModel
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.user import User
from fastapi import APIRouter
from superagi.helper.auth import check_auth, get_current_user
from superagi.lib.logger import logger
from superagi.models.models_config import ModelsConfig
# from superagi.types.db import UserBase, UserIn, UserOut
router = APIRouter()
class UserBase(BaseModel):
    """Common user fields shared by the user input and output schemas."""
    name: str
    email: str
    password: str

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
class UserOut(UserBase):
    """
    Response schema for a user record.

    NOTE(review): this inherits ``password`` from UserBase, so every endpoint
    using UserOut as its response_model returns the stored password field.
    Consider a separate output base without it.
    """
    id: int
    organisation_id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
class UserIn(UserBase):
    """Request schema for creating a user; the organisation id is optional."""
    organisation_id: Optional[int]

    class Config:
        # Allow building the schema directly from ORM objects (attribute access).
        orm_mode = True
# CRUD Operations
@router.post("/add", response_model=UserOut, status_code=201)
def create_user(user: UserIn, Authorize: AuthJWT = Depends(check_auth)):
    """
    Create a new user, or return the existing user with the same e-mail.

    For a new user, an organisation and a default project are found or
    created, and the local LLM configuration is added.

    Args:
        user (UserIn): User data.

    Returns:
        User: The created (or already existing) user.

    Raises:
        HTTPException (status_code=422): If name, email or password is missing/empty.
    """
    # Log only non-sensitive fields. The pydantic repr of UserIn includes the
    # password, so formatting the whole model wrote it to the logs.
    logger.info(f"Received user data: name={user.name}, email={user.email}")
    # Validate incoming request data
    if not user.name or not user.email or not user.password:
        logger.error("Missing required fields: name, email, or password")
        raise HTTPException(status_code=422, detail="Missing required fields: name, email, or password")
    db_user = db.session.query(User).filter(User.email == user.email).first()
    if db_user:
        return db_user
    db_user = User(name=user.name, email=user.email, password=user.password, organisation_id=user.organisation_id)
    db.session.add(db_user)
    db.session.commit()
    db.session.flush()
    organisation = Organisation.find_or_create_organisation(db.session, db_user)
    Project.find_or_create_default_project(db.session, organisation.id)
    # Avoid the model's repr, which may include the password column.
    logger.info(f"User created: id={db_user.id}, email={db_user.email}")
    # Adding local LLM configuration
    ModelsConfig.add_llm_config(db.session, organisation.id)
    return db_user
@router.get("/get/{user_id}", response_model=UserOut)
def get_user(user_id: int,
             Authorize: AuthJWT = Depends(check_auth)):
    """
    Fetch one user by primary key.

    Args:
        user_id (int): ID of the user.

    Returns:
        User: The matching user.

    Raises:
        HTTPException (status_code=404): When no user has this ID.
    """
    match = db.session.query(User).filter(User.id == user_id).first()
    if not match:
        raise HTTPException(status_code=404, detail="User not found")
    return match
@router.put("/update/{user_id}", response_model=UserOut)
def update_user(user_id: int,
                user: UserBase,
                Authorize: AuthJWT = Depends(check_auth)):
    """
    Overwrite a user's name, email and password.

    Args:
        user_id (int): ID of the user to modify.
        user (UserBase): Replacement name, email and password.

    Returns:
        User: The updated user.

    Raises:
        HTTPException (status_code=404): When no user has this ID.
    """
    target = db.session.query(User).filter(User.id == user_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")
    for attr in ("name", "email", "password"):
        setattr(target, attr, getattr(user, attr))
    db.session.commit()
    return target
@router.post("/first_login_source/{source}")
def update_first_login_source(source: str, Authorize: AuthJWT = Depends(check_auth)):
    """
    Record how the current user first logged in.

    The value is only written when no first-login source is stored yet. Later logins
    through a different provider therefore do not overwrite it.

    Args:
        source (str): Login source reported by the client. It is not validated here.

    Returns:
        User: The current user.

    Raises:
        HTTPException (status_code=401): If no user can be resolved for the request.
    """
    user = get_current_user(Authorize)
    if user is None:
        # Same response get_user_organisation gives for an unresolved user.
        raise HTTPException(status_code=401, detail="Unauthenticated")
    # valid_sources = ['google', 'github', 'email']
    if user.first_login_source is None or user.first_login_source == '':
        user.first_login_source = source
    db.session.commit()
    # Interpolate the user into the message itself. A stray positional argument without
    # a matching placeholder is not interpolated into the log line.
    logger.info(f"User : {user}")
    return user
================================================
FILE: superagi/controllers/vector_db_indices.py
================================================
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Query
from fastapi import APIRouter
from superagi.helper.auth import get_user_organisation
from superagi.models.vector_dbs import Vectordbs
from superagi.models.vector_db_indices import VectordbIndices
from superagi.models.knowledges import Knowledges
from superagi.models.knowledge_configs import KnowledgeConfigs
router = APIRouter()


@router.get("/marketplace/valid_indices/{knowledge_name}")
def get_marketplace_valid_indices(knowledge_name: str, organisation = Depends(get_user_organisation)):
    """
    List the organisation's vector indices for a marketplace knowledge, grouped by DB type.

    Each entry has ``id``, ``name`` and two flags:
    - ``is_valid_dimension``: the index dimensions equal the knowledge's configured
      ``dimensions``. This is always True for Weaviate.
    - ``is_valid_state``: the index state is not "Custom".

    Indices of databases whose type is not Pinecone, Qdrant or Weaviate are omitted.

    Args:
        knowledge_name (str): Name of the marketplace knowledge.
        organisation: Organisation of the authenticated user.

    Returns:
        dict: ``{"pinecone": [...], "qdrant": [...], "weaviate": [...]}``.
    """
    vector_dbs = Vectordbs.get_vector_db_from_organisation(db.session, organisation)
    knowledge = Knowledges.fetch_knowledge_details_marketplace(knowledge_name)
    knowledge_with_config = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(knowledge['id'])

    buckets = {"pinecone": [], "qdrant": [], "weaviate": []}
    bucket_for_type = {"Pinecone": "pinecone", "Qdrant": "qdrant", "Weaviate": "weaviate"}
    for vector_db in vector_dbs:
        indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id)
        for index in indices:
            entry = {"id": index.id, "name": index.name}
            entry["is_valid_dimension"] = bool(index.dimensions == int(knowledge_with_config["dimensions"]))
            entry["is_valid_state"] = bool(index.state != "Custom")
            bucket = bucket_for_type.get(vector_db.db_type)
            if bucket == "weaviate":
                # Weaviate indices are flagged valid whatever their dimensions.
                entry["is_valid_dimension"] = True
            if bucket is not None:
                buckets[bucket].append(entry)
    return buckets
@router.get("/user/valid_indices")
def get_user_valid_indices(organisation = Depends(get_user_organisation)):
    """
    List the organisation's vector indices grouped by DB type, flagging "Custom" indices.

    Each entry has ``id``, ``name`` and ``is_valid_state``, which is True only when the
    index state is "Custom". Indices of databases whose type is not Pinecone, Qdrant or
    Weaviate are omitted.

    Args:
        organisation: Organisation of the authenticated user.

    Returns:
        dict: ``{"pinecone": [...], "qdrant": [...], "weaviate": [...]}``.
    """
    buckets = {"pinecone": [], "qdrant": [], "weaviate": []}
    bucket_for_type = {"Pinecone": "pinecone", "Qdrant": "qdrant", "Weaviate": "weaviate"}
    for vector_db in Vectordbs.get_vector_db_from_organisation(db.session, organisation):
        for index in VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id):
            entry = {"id": index.id, "name": index.name, "is_valid_state": bool(index.state == "Custom")}
            bucket = bucket_for_type.get(vector_db.db_type)
            if bucket is not None:
                buckets[bucket].append(entry)
    return buckets
================================================
FILE: superagi/controllers/vector_dbs.py
================================================
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends
from fastapi import APIRouter
from superagi.config.config import get_config
from datetime import datetime
from superagi.helper.time_helper import get_time_difference
from superagi.models.vector_dbs import Vectordbs
from superagi.helper.auth import get_user_organisation
from superagi.models.vector_db_configs import VectordbConfigs
from superagi.models.vector_db_indices import VectordbIndices
from superagi.vector_store.vector_factory import VectorFactory
from superagi.models.knowledges import Knowledges
router = APIRouter()


@router.get("/get/list")
def get_vector_db_list():
    """
    Return the vector databases offered in the marketplace.

    Returns:
        The result of ``Vectordbs.fetch_marketplace_list()``, unchanged.
    """
    return Vectordbs.fetch_marketplace_list()
@router.get("/marketplace/list")
def get_marketplace_vectordb_list():
    """
    Return every vector database owned by the marketplace organisation.

    The organisation is identified by the ``MARKETPLACE_ORGANISATION_ID`` config value.
    """
    marketplace_org_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
    owned_by_marketplace = db.session.query(Vectordbs).filter(Vectordbs.organisation_id == marketplace_org_id)
    return owned_by_marketplace.all()
@router.get("/user/list")
def get_user_connected_vector_db_list(organisation = Depends(get_user_organisation)):
    """
    List the vector databases connected by the user's organisation.

    Before returning, each record's ``updated_at`` is overwritten in place with
    ``get_time_difference(updated_at, now)``.
    """
    connected_dbs = Vectordbs.get_vector_db_from_organisation(db.session, organisation)
    for connected_db in connected_dbs or []:
        connected_db.updated_at = get_time_difference(connected_db.updated_at, str(datetime.now()))
    return connected_dbs
@router.get("/db/details/{vector_db_id}")
def get_vector_db_details(vector_db_id: int):
    """
    Return a vector database together with its stored configuration and index names.

    Args:
        vector_db_id (int): ID of the vector database.

    Returns:
        dict: The database's ``id``, ``name`` and ``db_type``, merged with its configuration
        entries, plus ``indices`` (the list of index names).

    Raises:
        HTTPException (status_code=404): If no vector database has the given ID.
    """
    vector_db = Vectordbs.get_vector_db_from_id(db.session, vector_db_id)
    if vector_db is None:
        # Report a missing DB as 404 instead of crashing on attribute access.
        raise HTTPException(status_code=404, detail="VectorDb not found")
    vector_db_data = {
        "id": vector_db.id,
        "name": vector_db.name,
        "db_type": vector_db.db_type
    }
    vector_db_config = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector_db_id)
    vector_db_with_config = vector_db_data | vector_db_config
    indices = db.session.query(VectordbIndices).filter(VectordbIndices.vector_db_id == vector_db_id).all()
    vector_db_with_config["indices"] = [index.name for index in indices]
    return vector_db_with_config
@router.post("/delete/{vector_db_id}")
def delete_vector_db(vector_db_id: int):
    """
    Delete a vector database along with its indices, their knowledge, and its configs.

    Args:
        vector_db_id (int): ID of the vector database to delete.

    Raises:
        HTTPException (status_code=404): If any step of the deletion fails. The error is
            reported as "VectorDb not found", and the original exception is chained as
            the cause.
    """
    try:
        vector_indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db_id)
        for vector_index in vector_indices:
            Knowledges.delete_knowledge_from_vector_index(db.session, vector_index.id)
            VectordbIndices.delete_vector_db_index(db.session, vector_index.id)
        VectordbConfigs.delete_vector_db_configs(db.session, vector_db_id)
        Vectordbs.delete_vector_db(db.session, vector_db_id)
    except Exception as error:
        # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        raise HTTPException(status_code=404, detail="VectorDb not found") from error
@router.post("/connect/pinecone")
def connect_pinecone_vector_db(data: dict, organisation = Depends(get_user_organisation)):
    """
    Connect a Pinecone account and register the given indices for the organisation.

    Every collection is probed before anything is stored. A single unreachable index
    therefore aborts the request without creating a database record.

    Args:
        data (dict): Must contain ``name``, ``api_key``, ``environment`` and
            ``collections`` (index names).
        organisation: Organisation of the authenticated user.

    Returns:
        dict: ``id`` and ``name`` of the newly stored vector database.

    Raises:
        HTTPException (status_code=400): If any collection cannot be reached.
    """
    db_creds = {
        "api_key": data["api_key"],
        "environment": data["environment"]
    }
    # Keep each collection's own stats. State and dimensions must be recorded per index,
    # not copied from whichever collection happened to be probed last.
    probed_collections = []
    for collection in data["collections"]:
        try:
            vector_db_storage = VectorFactory.build_vector_storage("pinecone", collection, **db_creds)
            index_stats = vector_db_storage.get_index_stats()
            index_state = "Custom" if index_stats["vector_count"] > 0 else "None"
        except Exception as error:
            raise HTTPException(status_code=400, detail="Unable to connect Pinecone") from error
        probed_collections.append((collection, index_state, index_stats))

    pinecone_db = Vectordbs.add_vector_db(db.session, data["name"], "Pinecone", organisation)
    VectordbConfigs.add_vector_db_config(db.session, pinecone_db.id, db_creds)
    for collection, index_state, index_stats in probed_collections:
        VectordbIndices.add_vector_index(db.session, collection, pinecone_db.id, index_state,
                                         index_stats["dimensions"])
    return {"id": pinecone_db.id, "name": pinecone_db.name}
@router.post("/connect/qdrant")
def connect_qdrant_vector_db(data: dict, organisation = Depends(get_user_organisation)):
    """
    Connect a Qdrant instance and register the given collections for the organisation.

    Every collection is probed before anything is stored. A single unreachable
    collection therefore aborts the request without creating a database record.

    Args:
        data (dict): Must contain ``name``, ``api_key``, ``url``, ``port`` and
            ``collections``.
        organisation: Organisation of the authenticated user.

    Returns:
        dict: ``id`` and ``name`` of the newly stored vector database.

    Raises:
        HTTPException (status_code=400): If any collection cannot be reached.
    """
    db_creds = {
        "api_key": data["api_key"],
        "url": data["url"],
        "port": data["port"]
    }
    # Keep each collection's own stats. State and dimensions must be recorded per
    # collection, not copied from whichever collection happened to be probed last.
    probed_collections = []
    for collection in data["collections"]:
        try:
            vector_db_storage = VectorFactory.build_vector_storage("qdrant", collection, **db_creds)
            index_stats = vector_db_storage.get_index_stats()
            index_state = "Custom" if index_stats["vector_count"] > 0 else "None"
        except Exception as error:
            raise HTTPException(status_code=400, detail="Unable to connect Qdrant") from error
        probed_collections.append((collection, index_state, index_stats))

    qdrant_db = Vectordbs.add_vector_db(db.session, data["name"], "Qdrant", organisation)
    VectordbConfigs.add_vector_db_config(db.session, qdrant_db.id, db_creds)
    for collection, index_state, index_stats in probed_collections:
        VectordbIndices.add_vector_index(db.session, collection, qdrant_db.id, index_state,
                                         index_stats["dimensions"])
    return {"id": qdrant_db.id, "name": qdrant_db.name}
@router.post("/connect/weaviate")
def connect_weaviate_vector_db(data: dict, organisation = Depends(get_user_organisation)):
    """
    Connect a Weaviate instance and register the given classes for the organisation.

    Every collection is probed before anything is stored. A single unreachable
    collection therefore aborts the request without creating a database record.

    Args:
        data (dict): Must contain ``name``, ``api_key``, ``url`` and ``collections``.
        organisation: Organisation of the authenticated user.

    Returns:
        dict: ``id`` and ``name`` of the newly stored vector database.

    Raises:
        HTTPException (status_code=400): If any collection cannot be reached.
    """
    db_creds = {
        "api_key": data["api_key"],
        "url": data["url"]
    }
    # Record each collection's own state rather than the state of the last one probed.
    probed_collections = []
    for collection in data["collections"]:
        try:
            vector_db_storage = VectorFactory.build_vector_storage("weaviate", collection, **db_creds)
            index_stats = vector_db_storage.get_index_stats()
            index_state = "Custom" if index_stats["vector_count"] > 0 else "None"
        except Exception as error:
            raise HTTPException(status_code=400, detail="Unable to connect Weaviate") from error
        probed_collections.append((collection, index_state))

    weaviate_db = Vectordbs.add_vector_db(db.session, data["name"], "Weaviate", organisation)
    VectordbConfigs.add_vector_db_config(db.session, weaviate_db.id, db_creds)
    for collection, index_state in probed_collections:
        VectordbIndices.add_vector_index(db.session, collection, weaviate_db.id, index_state)
    return {"id": weaviate_db.id, "name": weaviate_db.name}
@router.put("/update/vector_db/{vector_db_id}")
def update_vector_db(new_indices: list, vector_db_id: int):
    """
    Synchronise the indices registered for a vector database with ``new_indices``.

    Registered indices that are not listed are deleted. Listed names that are not
    registered yet are probed with the database's stored credentials and then added.

    Args:
        new_indices (list): Complete list of index names that should stay registered.
        vector_db_id (int): ID of the vector database.

    Raises:
        HTTPException (status_code=400): If a newly added index cannot be reached.
    """
    vector_db = Vectordbs.get_vector_db_from_id(db.session, vector_db_id)
    existing_indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db_id)
    existing_index_names = set()
    for index in existing_indices:
        if index.name not in new_indices:
            VectordbIndices.delete_vector_db_index(db.session, vector_index_id=index.id)
        existing_index_names.add(index.name)

    added_indices = set(new_indices) - existing_index_names
    if not added_indices:
        return

    # The credentials belong to the database, not to an index: load them once.
    db_creds = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector_db_id)
    for index in added_indices:
        try:
            # NOTE(review): db_type is stored capitalised ("Pinecone"), while the connect
            # endpoints pass lowercase names to VectorFactory. Confirm the factory is
            # case-insensitive.
            vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, index, **db_creds)
            vector_db_index_stats = vector_db_storage.get_index_stats()
            index_state = "Custom" if vector_db_index_stats["vector_count"] > 0 else "None"
            dimensions = vector_db_index_stats["dimensions"] if 'dimensions' in vector_db_index_stats else None
        except Exception as error:
            raise HTTPException(status_code=400, detail="Unable to update vector db") from error
        VectordbIndices.add_vector_index(db.session, index, vector_db_id, index_state, dimensions)
================================================
FILE: superagi/controllers/webhook.py
================================================
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException
from fastapi import Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
# from superagi.types.db import AgentOut, AgentIn
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.models.webhooks import Webhooks
router = APIRouter()


class WebHookIn(BaseModel):
    """Request body for creating a webhook; create_webhook stores these fields as given."""
    name: str
    url: str
    headers: dict
    filters: dict

    class Config:
        orm_mode = True
class WebHookOut(BaseModel):
    """Response model for a stored webhook, populated from the ORM object (orm_mode)."""
    id: int
    # ID of the organisation that owns the webhook.
    org_id: int
    name: str
    url: str
    headers: dict
    # Soft-delete flag; the get and edit endpoints only consider rows where it is False.
    is_deleted: bool
    created_at: datetime
    updated_at: datetime
    filters: dict

    class Config:
        orm_mode = True
class WebHookEdit(BaseModel):
    """Request body for editing a webhook: only the URL and the filters can be changed."""
    url: str
    filters: dict

    class Config:
        orm_mode = True
# CRUD Operations
@router.post("/add", response_model=WebHookOut, status_code=201)
def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth),
                   organisation=Depends(get_user_organisation)):
    """
    Create a new webhook owned by the authenticated user's organisation.

    Args:
        webhook (WebHookIn): Name, URL, headers and filters of the webhook.
        Authorize (AuthJWT): Authentication dependency.
        organisation: Organisation of the authenticated user; it becomes the owner.

    Returns:
        WebHookOut: The stored webhook, created with ``is_deleted=False``.
    """
    db_webhook = Webhooks(name=webhook.name, url=webhook.url, headers=webhook.headers, org_id=organisation.id,
                          is_deleted=False, filters=webhook.filters)
    db.session.add(db_webhook)
    # commit() flushes the pending insert itself, so no separate flush() is needed.
    db.session.commit()
    return db_webhook
@router.get("/get", response_model=Optional[WebHookOut])
def get_all_webhooks(
        Authorize: AuthJWT = Depends(check_auth),
        organisation=Depends(get_user_organisation),
):
    """
    Fetch the first active (not soft-deleted) webhook of the user's organisation.

    Despite its name, this endpoint returns a single webhook. It returns None when the
    organisation has no active webhook.
    """
    active_webhooks = db.session.query(Webhooks).filter(
        Webhooks.org_id == organisation.id,
        Webhooks.is_deleted == False,  # noqa: E712 -- SQL expression, not a Python comparison
    )
    return active_webhooks.first()
@router.post("/edit/{webhook_id}", response_model=WebHookOut)
def edit_webhook(
        updated_webhook: WebHookEdit,
        webhook_id: int,
        Authorize: AuthJWT = Depends(check_auth),
        organisation=Depends(get_user_organisation),
):
    """
    Update the URL and filters of one of the organisation's active webhooks.

    Args:
        updated_webhook (WebHookEdit): New URL and filters.
        webhook_id (int): The ID of the webhook to edit.

    Returns:
        WebHookOut: The updated webhook.

    Raises:
        HTTPException (Status Code=404): If the webhook does not exist, belongs to another
            organisation, or has been soft-deleted.
    """
    webhook = db.session.query(Webhooks).filter(Webhooks.org_id == organisation.id, Webhooks.id == webhook_id,
                                                Webhooks.is_deleted == False).first()
    if webhook is None:
        raise HTTPException(status_code=404, detail="Webhook not found")
    webhook.url = updated_webhook.url
    webhook.filters = updated_webhook.filters
    db.session.commit()
    return webhook
================================================
FILE: superagi/helper/agent_schedule_helper.py
================================================
from superagi.models.db import connect_db
from sqlalchemy.orm import sessionmaker
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_schedule import AgentSchedule
from datetime import datetime, timedelta
from superagi.helper.time_helper import parse_interval_to_seconds
import pytz
# Module-level database engine and session factory; the AgentScheduleHelper methods
# below open their sessions from this factory.
engine = connect_db()
Session = sessionmaker(bind=engine)
class AgentScheduleHelper:
    """Executes due agent schedules and repairs schedules whose run time was missed."""

    # Grace period in seconds. update_next_scheduled_time leaves alone schedules that
    # fell due less than this long ago.
    AGENT_SCHEDULE_TIME_INTERVAL = 300

    def run_scheduled_agents(self):
        """
        Execute all eligible scheduled agent tasks since last five minutes.

        This covers every "SCHEDULED" schedule whose next_scheduled_time lies within the
        last five minutes. Each one is executed if it is still allowed to run, then marked
        "COMPLETED" when it has no further runs left.
        """
        now = datetime.now()
        last_five_minutes = now - timedelta(minutes=5)
        session = Session()
        try:
            scheduled_agents = session.query(AgentSchedule).filter(
                AgentSchedule.next_scheduled_time.between(last_five_minutes, now),
                AgentSchedule.status == "SCHEDULED").all()

            for agent in scheduled_agents:
                interval = agent.recurrence_interval
                interval_in_seconds = parse_interval_to_seconds(interval) if interval is not None else 0
                agent_execution_name = self.__create_execution_name_for_scheduling(agent.agent_id)
                should_execute_agent = self.__should_execute_agent(agent, interval)
                self.__execute_schedule(should_execute_agent, interval_in_seconds, session, agent,
                                        agent_execution_name)

            for agent in scheduled_agents:
                # Judge each schedule by its own recurrence interval, not by the interval of
                # whichever schedule the loop above processed last.
                if self.__can_remove_agent(agent, agent.recurrence_interval):
                    agent.status = "COMPLETED"
            session.commit()
        finally:
            # Release the connection even if an execution or the commit fails.
            session.close()

    def update_next_scheduled_time(self):
        """
        Update the next scheduled time of each agent and terminate those who have finished their schedule, in case of any miss.

        Only "SCHEDULED" schedules that have started are touched, and only when their
        next_scheduled_time is at least AGENT_SCHEDULE_TIME_INTERVAL seconds in the past.
        Recurring schedules move to the first interval boundary (counted from start_time)
        after now. Non-recurring schedules are marked "TERMINATED".
        """
        now = datetime.now()
        session = Session()
        try:
            scheduled_agents = session.query(AgentSchedule).filter(
                AgentSchedule.start_time <= now, AgentSchedule.next_scheduled_time <= now,
                AgentSchedule.status == "SCHEDULED").all()
            for agent in scheduled_agents:
                if (now - agent.next_scheduled_time).total_seconds() < AgentScheduleHelper.AGENT_SCHEDULE_TIME_INTERVAL:
                    continue
                if agent.recurrence_interval is not None:
                    interval_in_seconds = parse_interval_to_seconds(agent.recurrence_interval)
                    time_diff = now - agent.start_time
                    num_intervals_passed = time_diff.total_seconds() // interval_in_seconds
                    agent.next_scheduled_time = agent.start_time + timedelta(
                        seconds=(interval_in_seconds * (num_intervals_passed + 1)))
                else:
                    agent.status = "TERMINATED"
            session.commit()
        finally:
            session.close()

    def __create_execution_name_for_scheduling(self, agent_id) -> str:
        """
        Create name for an agent execution based on current time.

        The time is rendered in the agent's "user_timezone" configuration value when one
        is set and is not the string "None". Otherwise it is rendered in GMT.

        Args:
            agent_id (str): The id of the agent job to be scheduled.

        Returns:
            str: Execution name of the form "Run DD Month YYYY HH:MM".
        """
        session = Session()
        try:
            user_timezone = session.query(AgentConfiguration).filter(AgentConfiguration.key == "user_timezone",
                                                                      AgentConfiguration.agent_id == agent_id).first()
            timezone_name = user_timezone.value if user_timezone else None
        finally:
            # This helper opens its own session; close it so connections are not leaked.
            session.close()
        if timezone_name and timezone_name != "None":
            current_time = datetime.now().astimezone(pytz.timezone(timezone_name))
        else:
            current_time = datetime.now().astimezone(pytz.timezone('GMT'))
        timestamp = current_time.strftime(" %d %B %Y %H:%M")
        return f"Run{timestamp}"

    def __should_execute_agent(self, agent, interval):
        """
        Determine if an agent should be executed based on its scheduling.

        Args:
            agent (object): The schedule to evaluate (reads expiry_date, expiry_runs and
                current_runs; expiry_runs == -1 means "no run limit").
            interval: The schedule's recurrence_interval as stored, not in seconds. Only its
                truthiness is used here.

        Returns:
            bool: True if the agent should be executed, False otherwise.
        """
        expiry_date = agent.expiry_date
        expiry_runs = agent.expiry_runs
        current_runs = agent.current_runs

        # If there's no interval or there are no restrictions on when or how many times an agent can run
        if not interval or (expiry_date is None and expiry_runs == -1):
            return True
        # Check if the agent's expiry date has not passed yet
        if expiry_date and datetime.now() < expiry_date:
            return True
        # Check if the agent has not yet run as many times as allowed
        if expiry_runs != -1 and current_runs < expiry_runs:
            return True
        # None of the conditions to run the agent is met
        return False

    def __can_remove_agent(self, agent, interval):
        """
        Determine if an agent can be removed based on its scheduled expiry.

        Args:
            agent (object): The schedule to evaluate (reads expiry_date, expiry_runs,
                current_runs and next_scheduled_time).
            interval: The schedule's recurrence_interval as stored. It is converted with
                parse_interval_to_seconds when it is truthy.

        Returns:
            bool: True if the agent can be removed, False otherwise.
        """
        # A one-off schedule (no interval) is done after its single run.
        if not interval:
            return True

        expiry_date = agent.expiry_date
        expiry_runs = agent.expiry_runs
        current_runs = agent.current_runs
        next_scheduled = agent.next_scheduled_time + timedelta(seconds=parse_interval_to_seconds(interval))

        # If the expiry date has not come yet and the next run is before it, keep the schedule
        if expiry_date and datetime.now() < expiry_date and next_scheduled <= expiry_date:
            return False
        # If the agent has not yet run as many times as allowed, keep it
        if expiry_runs != -1 and current_runs < expiry_runs:
            return False
        # If there are no restrictions on when or how many times the agent can run, keep it
        if expiry_date is None and expiry_runs == -1:
            return False
        # None of the conditions to keep the agent is met
        return True

    def __execute_schedule(self, should_execute_agent, interval_in_seconds, session, agent, agent_execution_name):
        """
        Executes a scheduled job, if it should be executed.

        On execution, the run counter is incremented. For recurring schedules,
        next_scheduled_time is advanced by one interval. The changes are committed on the
        given session.

        Args:
            should_execute_agent (bool): Whether agent should be executed.
            interval_in_seconds (int): The interval in seconds for the schedule.
            session (Session): The database session.
            agent (object): The agent to be scheduled.
            agent_execution_name (str): The name for the execution.
        """
        # Imported lazily, as in the original module layout.
        from superagi.jobs.scheduling_executor import ScheduledAgentExecutor

        if should_execute_agent:
            executor = ScheduledAgentExecutor()
            executor.execute_scheduled_agent(agent.agent_id, agent_execution_name)
            agent.current_runs = agent.current_runs + 1
            if agent.recurrence_interval:
                agent.next_scheduled_time = agent.next_scheduled_time + timedelta(seconds=interval_in_seconds)
            session.commit()
================================================
FILE: superagi/helper/auth.py
================================================
from typing import Optional

from fastapi import Depends, HTTPException, Header, Request, Security, status
from fastapi.security import APIKeyHeader
from fastapi.security.api_key import APIKeyHeader
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from sqlalchemy import or_

from superagi.config.config import get_config
from superagi.models.api_key import ApiKey
from superagi.models.organisation import Organisation
from superagi.models.user import User
def check_auth(Authorize: AuthJWT = Depends()):
    """
    Require a valid JWT, but only when the application runs with ENV=PROD.

    ENV defaults to "DEV". In any environment other than PROD the request passes
    unchecked.

    Args:
        Authorize (AuthJWT, optional): Injected AuthJWT helper. Defaults to Depends().

    Returns:
        AuthJWT: The same AuthJWT instance, for downstream dependencies.
    """
    running_in_prod = get_config("ENV", "DEV") == "PROD"
    if running_in_prod:
        Authorize.jwt_required()
    return Authorize
def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)):
    """
    Look up the organisation of the currently authenticated user.

    Args:
        Authorize (AuthJWT, optional): AuthJWT helper, obtained through check_auth.

    Returns:
        Organisation: The user's organisation, or None if no organisation row matches.

    Raises:
        HTTPException (status_code=401): If no user can be resolved.
    """
    current_user = get_current_user(Authorize)
    if current_user is None:
        raise HTTPException(status_code=401, detail="Unauthenticated")
    org_query = db.session.query(Organisation).filter(Organisation.id == current_user.organisation_id)
    return org_query.first()
def get_current_user(Authorize: AuthJWT = Depends(check_auth), request: Request = None):
    """
    Resolve the currently authenticated user.

    With ENV=DEV (the default), the fixed development account "super6@agi.com" is used.
    Otherwise the email comes from an HTTP Basic ``Authorization`` header when a request
    is available and carries one, and from the JWT subject in all other cases.

    Args:
        Authorize (AuthJWT): AuthJWT helper, normally provided through check_auth.
        request (Request, optional): The incoming request. FastAPI injects it when this
            function is used as a dependency. Direct callers such as
            get_user_organisation pass only ``Authorize``; then only the JWT subject is
            consulted. It is annotated as plain ``Request`` rather than Optional so that
            FastAPI recognises it and injects it.

    Returns:
        User | None: The user with the resolved email, or None if there is none.

    Raises:
        HTTPException (status_code=401): If a Basic ``Authorization`` header is not valid
            base64-encoded UTF-8.
    """
    env = get_config("ENV", "DEV")
    if env == "DEV":
        email = "super6@agi.com"
    else:
        # Check for HTTP basic auth headers
        auth_header = request.headers.get('Authorization') if request is not None else None
        if auth_header and auth_header.startswith('Basic '):
            import base64
            try:
                auth_decoded = base64.b64decode(auth_header.split(' ', 1)[1]).decode('utf-8')
            except ValueError as error:
                # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
                raise HTTPException(status_code=401, detail="Invalid authorization header") from error
            # partition() keeps passwords that contain ':' intact and tolerates a missing ':'.
            username, _, _password = auth_decoded.partition(':')
            # NOTE(review): the Basic-auth password is decoded but never checked. Any caller
            # who knows a user's email is therefore treated as that user. Verify it against
            # the stored credential before relying on this branch.
            # Assuming username is the email
            email = username
        else:
            # Retrieve the email of the logged-in user from the JWT token payload
            email = Authorize.get_jwt_subject()
    # Query the User table to find the user by their email
    user = db.session.query(User).filter(User.email == email).first()
    return user
# Security scheme that reads the API key from the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")


def validate_api_key(api_key: str = Security(api_key_header)) -> str:
    """
    Check that the supplied API key exists and has not expired.

    A key whose ``is_expired`` column is False or NULL counts as valid.

    Returns:
        str: The validated key.

    Raises:
        HTTPException (401): If no matching, unexpired key exists.
    """
    not_expired = or_(ApiKey.is_expired == False, ApiKey.is_expired == None)  # noqa: E711,E712
    matching_key = db.session.query(ApiKey).filter(ApiKey.key == api_key, not_expired).first()
    if matching_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API Key",
        )
    return matching_key.key
def get_organisation_from_api_key(api_key: str = Security(api_key_header)) -> Organisation:
    """
    Resolve the organisation that owns a valid (unexpired) API key.

    Returns:
        Organisation: The key's organisation, or None if that organisation row is missing.

    Raises:
        HTTPException (401): If no matching, unexpired key exists.
    """
    api_key_row = (
        db.session.query(ApiKey)
        .filter(ApiKey.key == api_key, or_(ApiKey.is_expired == False, ApiKey.is_expired == None))  # noqa: E711,E712
        .first()
    )
    if api_key_row is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API Key",
        )
    return db.session.query(Organisation).filter(Organisation.id == api_key_row.org_id).first()
================================================
FILE: superagi/helper/calendar_date.py
================================================
from datetime import datetime, timedelta, timezone
import pytz
class CalendarDate:
    """Turns calendar date/time strings, given in a calendar's time zone, into UTC timestamps."""

    # strftime format of the UTC timestamps produced by this class.
    _UTC_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

    def create_event_dates(self, service, start_date, start_time, end_date, end_time):
        """
        Build UTC start/end timestamps for an event, plus the calendar's time zone.

        Args:
            service: Calendar API client. The ``timeZone`` of its primary calendar is used
                to interpret the dates and times.
            start_date, end_date (str): ``YYYY-MM-DD``. ``'None'``, empty or None selects
                the defaults described in ``_localize_daterange``.
            start_time, end_time (str): ``HH:MM:SS``, or ``'None'``/empty/None.

        Returns:
            dict: ``start_datetime_utc``, ``end_datetime_utc`` and ``timeZone``.
        """
        # Fetch the time zone once; each lookup is a Calendar API request.
        time_zone = self._get_time_zone(service)
        start_datetime, end_datetime = self._localize_daterange(start_date, end_date, start_time, end_time,
                                                                pytz.timezone(time_zone))
        return {
            "start_datetime_utc": self._datetime_to_string(start_datetime, self._UTC_FORMAT),
            "end_datetime_utc": self._datetime_to_string(end_datetime, self._UTC_FORMAT),
            "timeZone": time_zone
        }

    def get_date_utc(self, start_date, end_date, start_time, end_time, service):
        """
        Build UTC start/end timestamps for a date range in the calendar's time zone.

        Arguments are interpreted as in ``create_event_dates``.

        Returns:
            dict: ``start_datetime_utc`` and ``end_datetime_utc``.
        """
        local_tz = pytz.timezone(self._get_time_zone(service))
        start_datetime, end_datetime = self._localize_daterange(start_date, end_date, start_time, end_time, local_tz)
        return {
            "start_datetime_utc": self._datetime_to_string(start_datetime, self._UTC_FORMAT),
            "end_datetime_utc": self._datetime_to_string(end_datetime, self._UTC_FORMAT)
        }

    def _get_time_zone(self, service):
        """Return the ``timeZone`` of the service's primary calendar."""
        calendar = service.calendars().get(calendarId='primary').execute()
        return calendar['timeZone']

    def _convert_to_utc(self, date_time, local_tz):
        """Interpret naive ``date_time`` in ``local_tz`` (via ``localize``) and convert it to UTC."""
        return local_tz.localize(date_time).astimezone(timezone.utc)

    def _string_to_datetime(self, date_str, date_format):
        """Parse ``date_str`` with ``date_format``; None, "" and the string 'None' give None."""
        if not date_str or date_str == 'None':
            return None
        return datetime.strptime(date_str, date_format)

    def _localize_daterange(self, start_date, end_date, start_time, end_time, local_tz):
        """
        Turn date/time strings expressed in ``local_tz`` into a pair of UTC datetimes.

        Defaults for missing values:
        - start date: the current date/time in ``local_tz``.
        - end date: start date + 30 days - 1 microsecond.
        - start time: 00:00:00.000000.
        - end time: 23:59:59.999999.

        Args:
            start_date, end_date (str): ``YYYY-MM-DD`` or missing.
            start_time, end_time (str): ``HH:MM:SS`` or missing.
            local_tz: pytz-style time zone exposing ``localize()``.

        Returns:
            tuple: (start, end) as timezone-aware UTC datetimes.
        """
        start_datetime = self._string_to_datetime(start_date, "%Y-%m-%d")
        if start_datetime is None:
            # Use a naive wall-clock time: localize() rejects tz-aware datetimes.
            start_datetime = datetime.now(local_tz).replace(tzinfo=None)
        end_datetime = self._string_to_datetime(end_date, "%Y-%m-%d")
        if end_datetime is None:
            end_datetime = start_datetime + timedelta(days=30) - timedelta(microseconds=1)

        time_obj_start = self._string_to_datetime(start_time, "%H:%M:%S")
        time_obj_end = self._string_to_datetime(end_time, "%H:%M:%S")

        if time_obj_start is not None:
            start_datetime = start_datetime.replace(hour=time_obj_start.hour, minute=time_obj_start.minute,
                                                    second=time_obj_start.second, microsecond=0)
        else:
            start_datetime = start_datetime.replace(hour=0, minute=0, second=0, microsecond=0)

        if time_obj_end is not None:
            # Reset microseconds too, so a defaulted end date cannot leak .999999 into an explicit time.
            end_datetime = end_datetime.replace(hour=time_obj_end.hour, minute=time_obj_end.minute,
                                                second=time_obj_end.second, microsecond=0)
        else:
            end_datetime = end_datetime.replace(hour=23, minute=59, second=59, microsecond=999999)

        return self._convert_to_utc(start_datetime, local_tz), self._convert_to_utc(end_datetime, local_tz)

    def _datetime_to_string(self, date_time, date_format):
        """Format ``date_time`` with ``date_format``; None stays None."""
        return date_time.strftime(date_format) if date_time else None
================================================
FILE: superagi/helper/encyption_helper.py
================================================
import base64
from cryptography.fernet import Fernet, InvalidToken, InvalidSignature
from superagi.config.config import get_config
from superagi.lib.logger import logger
# Generate a key
# key = Fernet.generate_key()
key = get_config("ENCRYPTION_KEY")
if key is None:
    raise Exception("Encryption key not found in config file.")

# Encode to UTF-8 *before* checking the length. The requirement is 32 bytes, and a key
# containing non-ASCII characters has more bytes than characters.
key = key.encode(
    "utf-8"
)
if len(key) != 32:
    raise ValueError("Encryption key must be 32 bytes long.")

# Fernet expects the 32 raw key bytes in URL-safe base64 form.
key = base64.urlsafe_b64encode(key)

# Create a cipher suite
cipher_suite = Fernet(key)
def encrypt_data(data):
    """
    Encrypt a string with the module-level Fernet cipher.

    Args:
        data (str): Plain text; it is encoded with the default (UTF-8) codec first.

    Returns:
        str: The resulting Fernet token as text.
    """
    return cipher_suite.encrypt(data.encode()).decode()
def decrypt_data(encrypted_data):
    """
    Decrypt a Fernet token produced by ``encrypt_data``.

    Args:
        encrypted_data (str): The token, as text.

    Returns:
        str: The recovered plain text.
    """
    plain_bytes = cipher_suite.decrypt(encrypted_data.encode())
    return plain_bytes.decode()
def is_encrypted(value):
    """
    Report whether ``value`` is a token that decrypts with the configured key.

    Any failure returns False instead of raising. This includes an invalid or tampered
    token, and a value of an unsuitable type or format.
    """
    try:
        Fernet(key).decrypt(value)
    except (InvalidToken, InvalidSignature, ValueError, TypeError):
        return False
    return True
================================================
FILE: superagi/helper/error_handler.py
================================================
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
class ErrorHandler:
    """Records agent errors as entries in the agent execution feed."""

    # Declared static: the method takes no self. Calls through the class keep working,
    # and calls through an instance no longer shift the arguments.
    @staticmethod
    def handle_openai_errors(session, agent_id, agent_execution_id, error_message):
        """
        Append a "system" feed entry carrying ``error_message`` to an agent execution.

        Args:
            session: Database session used for the lookup and the insert; committed here.
            agent_id: ID of the agent.
            agent_execution_id: ID of the execution the error belongs to.
            error_message (str): Error text stored on the feed entry.
        """
        execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
        # Still record the error if the execution row is missing, without a feed group.
        feed_group_id = execution.current_feed_group_id if execution is not None else None
        agent_feed = AgentExecutionFeed(agent_execution_id=agent_execution_id, agent_id=agent_id, role="system",
                                        feed="", error_message=error_message, feed_group_id=feed_group_id)
        session.add(agent_feed)
        session.commit()
================================================
FILE: superagi/helper/feed_parser.py
================================================
import json
from datetime import datetime
from superagi.helper.time_helper import get_time_difference
from superagi.lib.logger import logger
def parse_feed(feed):
    """
    Convert an agent execution feed entry into a display-friendly dict.

    First, ``feed.time_difference`` is set on the feed object itself from
    ``get_time_difference(feed.updated_at, now)``. The output then depends on the role:
    - ``assistant``: the JSON feed is summarised as "Thoughts:", "Plan:", "Criticism:"
      and "Tool:" lines. If it cannot be parsed or summarised, the raw text is kept.
    - ``system``: when the text mentions json-schema.org, everything from "TOOLS:" on is
      dropped.
    - ``user``: the text is passed through unchanged.

    Args:
        feed (AgentExecutionFeed): The feed to be parsed.

    Returns:
        dict: ``role``, ``feed``, ``updated_at`` and ``time_difference`` for the roles
        above. For any other role, the feed object itself is returned.
    """
    feed.time_difference = get_time_difference(feed.updated_at, str(datetime.now()))

    def _as_dict(role, text):
        return {"role": role, "feed": text, "updated_at": feed.updated_at,
                "time_difference": feed.time_difference}

    if feed.role == "assistant":
        try:
            parsed = json.loads(feed.feed, strict=False)
            thoughts = parsed["thoughts"]
            summary_parts = []
            if "reasoning" in thoughts:
                summary_parts.append("Thoughts: " + thoughts["reasoning"] + "\n")
            if "plan" in thoughts:
                summary_parts.append("Plan: " + str(thoughts["plan"]) + "\n")
            if "criticism" in thoughts:
                summary_parts.append("Criticism: " + thoughts["criticism"] + "\n")
            if "tool" in parsed:
                summary_parts.append("Tool: " + parsed["tool"]["name"] + "\n")
            if "command" in parsed:
                summary_parts.append("Tool: " + parsed["command"]["name"] + "\n")
            summary = "".join(summary_parts)
        except Exception:
            return _as_dict("assistant", feed.feed)
        return _as_dict("assistant", summary)

    if feed.role == "system":
        system_text = feed.feed.split("TOOLS:")[0] if "json-schema.org" in feed.feed else feed.feed
        return _as_dict("system", system_text)

    if feed.role == "user":
        return _as_dict("user", feed.feed)

    return feed
================================================
FILE: superagi/helper/github_helper.py
================================================
import base64
import re
import requests
from superagi.lib.logger import logger
from superagi.helper.resource_helper import ResourceHelper
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.types.storage_types import StorageType
from superagi.config.config import get_config
from superagi.helper.s3_helper import S3Helper
from datetime import timedelta, datetime
import json
class GithubHelper:
def __init__(self, github_access_token, github_username):
"""
Initializes the GithubHelper with the provided access token and username.
Args:
github_access_token (str): Personal GitHub access token.
github_username (str): GitHub username.
"""
self.github_access_token = github_access_token
self.github_username = github_username
def get_file_path(self, file_name, folder_path):
    """
    Join a folder path and a file name into a repository-relative path.

    Args:
        file_name (str): Name of the file.
        folder_path (str | None): Folder containing the file. None or "" means the
            repository root (search_repo passes None by default).

    Returns:
        str: ``"<folder_path>/<file_name>"``, or just ``file_name`` at the root.
    """
    # A missing folder must not leak into the path: formatting None would give "None<file>".
    if not folder_path:
        return file_name
    return f"{folder_path}/{file_name}"
def check_repository_visibility(self, repository_owner, repository_name):
    """
    Ask the GitHub API whether a repository is private.

    Args:
        repository_owner (str): Owner of the repository.
        repository_name (str): Name of the repository.

    Returns:
        bool | None: The repository's ``private`` flag, or None if the request did not
        return HTTP 200 (the failure is logged).
    """
    response = requests.get(
        f"https://api.github.com/repos/{repository_owner}/{repository_name}",
        headers={
            "Authorization": f"Token {self.github_access_token}",
            "Accept": "application/vnd.github.v3+json"
        },
    )
    if response.status_code != 200:
        logger.info(f"Failed to fetch repository information: {response.status_code} - {response.text}")
        return None
    return response.json()['private']
def search_repo(self, repository_owner, repository_name, file_name, folder_path=None):
    """
    Fetch the GitHub "contents" metadata of a file in a repository.

    Args:
        repository_owner (str): Owner of the repository.
        repository_name (str): Name of the repository.
        file_name (str): Name of the file to look up.
        folder_path (str, optional): Folder containing the file. Defaults to None.

    Returns:
        dict: Decoded JSON body of the contents API response.

    Raises:
        requests.HTTPError: If GitHub answers with an error status.
    """
    # Without a token the Authorization value is None, which requests leaves out.
    # NOTE(review): GitHub reads the media type from Accept; Content-Type on a GET has
    # no effect. Confirm the intent.
    request_headers = {
        "Authorization": f"token {self.github_access_token}" if self.github_access_token else None,
        "Content-Type": "application/vnd.github+json"
    }
    target_path = self.get_file_path(file_name, folder_path)
    response = requests.get(
        f'https://api.github.com/repos/{repository_owner}/{repository_name}/contents/{target_path}',
        headers=request_headers,
    )
    response.raise_for_status()
    return response.json()
def sync_branch(self, repository_owner, repository_name, base_branch, head_branch, headers):
    """
    Force-update ``head_branch`` in the user's copy of the repository so it
    points at the current tip of ``base_branch`` in the upstream repository.

    Args:
        repository_owner (str): Owner of the upstream repository.
        repository_name (str): Name of the repository (same name in the fork).
        base_branch (str): Upstream branch whose tip commit is used.
        head_branch (str): Branch of ``self.github_username``'s repository to move.
        headers (dict): Request headers.

    Returns:
        None. Success and failure are only logged.
    """
    base_branch_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/branches/{base_branch}'
    response = requests.get(base_branch_url, headers=headers)
    if response.status_code != 200:
        # Previously a failed lookup surfaced as an opaque KeyError on
        # response_json['commit'] instead of a log message like the PATCH path.
        logger.info(f'Failed to fetch {repository_owner}:{base_branch} branch: {response.status_code}')
        return
    base_commit_sha = response.json()['commit']['sha']
    head_branch_url = f'https://api.github.com/repos/{self.github_username}/{repository_name}/git/refs/heads/{head_branch}'
    data = {
        'sha': base_commit_sha,
        'force': True  # allow non-fast-forward moves of the fork's branch
    }
    response = requests.patch(head_branch_url, json=data, headers=headers)
    if response.status_code == 200:
        logger.info(
            f'Successfully synced {self.github_username}:{head_branch} branch with {repository_owner}:{base_branch}')
    else:
        logger.info('Failed to sync the branch. Check your inputs and permissions.')
def make_fork(self, repository_owner, repository_name, base_branch, headers):
    """
    Creates a fork of the given repository for the authenticated user and, on
    success, syncs the fork's ``base_branch`` with the upstream one.

    Args:
        repository_owner (str): Owner of the repository.
        repository_name (str): Name of the repository.
        base_branch (str): Branch synced from upstream after the fork is created.
        headers (dict): Request headers.

    Returns:
        int: Status code of the fork request (202 is treated as success).
    """
    fork_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/forks'
    fork_response = requests.post(fork_url, headers=headers)
    if fork_response.status_code == 202:
        logger.info('Fork created successfully.')
        self.sync_branch(repository_owner, repository_name, base_branch, base_branch, headers)
    else:
        # Format the detail into the message (a second positional argument without
        # a placeholder is a printf-style logging misuse) and don't KeyError when
        # the error body has no "message".
        logger.info(f"Failed to create the fork: {fork_response.json().get('message')}")
    return fork_response.status_code
def create_branch(self, repository_name, base_branch, head_branch, headers):
    """
    Creates ``head_branch`` in the authenticated user's copy of the repository,
    starting from the current commit of ``base_branch`` there.

    Args:
        repository_name (str): Name of the repository.
        base_branch (str): Existing branch the new branch starts from.
        head_branch (str): Name of the branch to create.
        headers (dict): Request headers.

    Returns:
        int: Status code of the branch creation request (201 created; 422 is
        treated as "branch already exists").
    """
    base_ref_url = f'https://api.github.com/repos/{self.github_username}/{repository_name}/git/refs/heads/{base_branch}'
    base_sha = requests.get(base_ref_url, headers=headers).json()['object']['sha']
    branch_url = f'https://api.github.com/repos/{self.github_username}/{repository_name}/git/refs'
    branch_params = {
        'ref': f'refs/heads/{head_branch}',
        'sha': base_sha
    }
    branch_response = requests.post(branch_url, json=branch_params, headers=headers)
    if branch_response.status_code == 201:
        logger.info('Branch created successfully.')
    elif branch_response.status_code == 422:
        # This message used to hard-code "new-file" regardless of the branch requested.
        logger.info(f'Branch {head_branch} already exists, making commits to {head_branch} branch')
    else:
        logger.info(f"Failed to create branch: {branch_response.json().get('message')}")
    return branch_response.status_code
def delete_file(self, repository_name, file_name, folder_path, commit_message, head_branch, headers):
    """
    Deletes a file from the authenticated user's copy of the repository.

    Args:
        repository_name (str): Name of the repository.
        file_name (str): Name of the file to delete.
        folder_path (str): Path to the folder containing the file.
        commit_message (str): Commit message.
        head_branch (str): Branch the deletion is committed to.
        headers (dict): Request headers.

    Returns:
        int: Status code of the file deletion request (200 is treated as success).
    """
    file_path = self.get_file_path(file_name, folder_path)
    file_url = f'https://api.github.com/repos/{self.github_username}/{repository_name}/contents/{file_path}'
    file_params = {
        'message': commit_message,
        # The contents API requires the blob sha of the file being removed.
        # NOTE(review): get_sha goes through search_repo, which passes no ref, so this
        # is presumably the sha on the default branch rather than head_branch — confirm.
        'sha': self.get_sha(self.github_username, repository_name, file_name, folder_path),
        'branch': head_branch
    }
    file_response = requests.delete(file_url, json=file_params, headers=headers)
    if file_response.status_code == 200:
        logger.info('File or folder deleted successfully.')
    else:
        # Format the body into the message instead of passing it as a stray arg.
        logger.info(f'Failed to Delete file or folder: {file_response.json()}')
    return file_response.status_code
def add_file(self, repository_owner, repository_name, file_name, folder_path, head_branch, base_branch, headers, commit_message, agent_id, agent_execution_id, session):
    """
    Uploads an agent resource file to the given repository via the contents API.

    Args:
        repository_owner (str): Owner of the repository.
        repository_name (str): Name of the repository.
        file_name (str): Name of the resource file to upload; also used as the
            file name inside the repository.
        folder_path (str): Folder in the repository to place the file in.
        head_branch (str): Branch the commit is made on.
        base_branch (str): Not used by this method; kept for interface compatibility.
        headers (dict): Request headers.
        commit_message (str): Commit message.
        agent_id (int): Id of the agent that owns the resource.
        agent_execution_id (int): Id of the agent execution that owns the resource.
        session: Database session used to look up the agent and execution.

    Returns:
        int: Status code of the upload request (201 created; 422 is treated as
        "file already exists").
    """
    body = self._get_file_contents(file_name, agent_id, agent_execution_id, session)
    # GitHub expects base64 of the file's bytes. Encoding as UTF-8 (rather than
    # ASCII) keeps any non-ASCII content from raising UnicodeEncodeError.
    file_content = base64.b64encode(body.encode("utf-8")).decode("ascii")
    file_path = self.get_file_path(file_name, folder_path)
    file_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/contents/{file_path}'
    file_params = {
        'message': commit_message,
        'content': file_content,
        'branch': head_branch
    }
    file_response = requests.put(file_url, json=file_params, headers=headers)
    if file_response.status_code == 201:
        logger.info('File content uploaded successfully.')
    elif file_response.status_code == 422:
        logger.info('File already exists')
    else:
        logger.info(f"Failed to upload file content: {file_response.json().get('message')}")
    return file_response.status_code
def create_pull_request(self, repository_owner, repository_name, head_branch, base_branch, headers):
    """
    Opens a pull request from ``self.github_username:head_branch`` into
    ``base_branch`` of the given repository.

    Args:
        repository_owner (str): Owner of the target repository.
        repository_name (str): Name of the repository.
        head_branch (str): Branch (in the user's repository) containing the changes.
        base_branch (str): Branch the changes should be merged into.
        headers (dict): Request headers.

    Returns:
        int: Status code of the pull request creation request (201 created; 422
        is treated as "changes added to an existing pull request").
    """
    pull_request_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls'
    pull_request_params = {
        'title': f'Pull request by {self.github_username}',
        'body': 'Please review and merge this change.',
        'head': f'{self.github_username}:{head_branch}',  # required for cross repository only
        'head_repo': repository_name,  # required for cross repository only
        'base': base_branch
    }
    pr_response = requests.post(pull_request_url, json=pull_request_params, headers=headers)
    if pr_response.status_code == 201:
        logger.info('Pull request created successfully.')
    elif pr_response.status_code == 422:
        logger.info('Added changes to already existing pull request')
    else:
        # Format the detail into the message and tolerate error bodies without "message".
        logger.info(f"Failed to create pull request: {pr_response.json().get('message')}")
    return pr_response.status_code
def get_sha(self, repository_owner, repository_name, file_name, folder_path=None):
    """
    Look up the blob sha of a file, as needed e.g. to delete it.

    Args:
        repository_owner (str): Owner of the repository.
        repository_name (str): Name of the repository.
        file_name (str): Name of the file.
        folder_path (str, optional): Folder containing the file.

    Returns:
        str: The ``sha`` field of the file's contents metadata.
    """
    return self.search_repo(repository_owner, repository_name, file_name, folder_path)['sha']
def get_content_in_file(self, repository_owner, repository_name, file_name, folder_path=None):
    """
    Return the text content of a file in a repository.

    Args:
        repository_owner (str): Owner of the repository.
        repository_name (str): Name of the repository.
        file_name (str): Name of the file to read.
        folder_path (str, optional): Folder containing the file.

    Returns:
        str: The file content, base64-decoded (then decoded as UTF-8) when the
        API reports ``encoding == 'base64'``, otherwise returned as given.
    """
    metadata = self.search_repo(repository_owner, repository_name, file_name, folder_path)
    raw_content = metadata['content']
    if metadata.get('encoding') != 'base64':
        return raw_content
    return base64.b64decode(raw_content).decode()
@classmethod
def validate_github_link(cls, link: str) -> bool:
    """
    Validate a GitHub repository link of the form
    ``http(s)://[www.]github.com/<owner>/<repository>``.

    Repository names may contain dots (e.g. ``vercel/next.js``); the previous
    pattern rejected such links.

    Returns:
        bool: True if the link is valid, False otherwise.
    """
    # Owner: word characters and hyphens; repository: additionally allows "."
    pattern = r'^https?://(?:www\.)?github\.com/[\w\-]+/[\w.\-]+$'
    return re.match(pattern, link) is not None
def _get_file_contents(self, file_name, agent_id, agent_execution_id, session):
    """
    Read the text of an agent resource so it can be uploaded to GitHub.

    Args:
        file_name (str): Name of the resource file.
        agent_id (int): Id of the agent owning the resource.
        agent_execution_id (int): Id of the agent execution owning the resource.
        session: Database session used to load the agent and execution.

    Returns:
        The file contents, read through S3Helper when STORAGE_TYPE is S3 and
        from the local filesystem (as UTF-8 text) otherwise.
    """
    final_path = ResourceHelper().get_agent_read_resource_path(
        file_name,
        agent=Agent.get_agent_from_id(session, agent_id),
        agent_execution=AgentExecution.get_agent_execution_from_id(session, agent_execution_id))
    if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
        return S3Helper().read_from_s3(final_path)
    # Text mode already yields str; the old ``.read().decode('utf-8')`` therefore
    # always raised AttributeError. Decode as UTF-8 explicitly instead.
    with open(final_path, "r", encoding="utf-8") as file:
        return file.read()
def get_pull_request_content(self, repository_owner, repository_name, pull_request_number):
    """
    Gets the diff of a specific pull request from a GitHub repository.

    The request asks for the ``application/vnd.github.v3.diff`` media type, so
    the body is the textual diff rather than JSON.

    Args:
        repository_owner (str): Owner of the repository.
        repository_name (str): Name of the repository.
        pull_request_number (int): Pull request number.

    Returns:
        str | None: The diff text, or None if the pull request was not found or
        the request failed (both cases are logged).
    """
    pull_request_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}'
    headers = {
        "Authorization": f"token {self.github_access_token}" if self.github_access_token else None,
        "Content-Type": "application/vnd.github+json",
        "Accept": "application/vnd.github.v3.diff",
    }
    response = requests.get(pull_request_url, headers=headers)
    if response.status_code == 200:
        logger.info('Successfully fetched pull request content.')
        return response.text
    elif response.status_code == 404:
        logger.warning('Pull request not found.')
    else:
        # Format the body into the message instead of passing it as a stray arg.
        logger.warning(f'Failed to fetch pull request content: {response.text}')
    return None
def get_latest_commit_id_of_pull_request(self, repository_owner, repository_name, pull_request_number):
    """
    Gets the latest commit id of a specific pull request from a GitHub repository.

    The "list commits on a pull request" endpoint is paginated (30 items per
    page by default). Only reading the first page returned a stale commit for
    longer pull requests, so every page is followed via the ``Link`` header.
    As before, the last commit listed is taken to be the latest.

    :param repository_owner: owner
    :param repository_name: repository name
    :param pull_request_number: pull request id
    :return:
        sha of the latest commit, or None if a request failed or the pull
        request has no commits
    """
    url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}/commits'
    headers = {
        "Authorization": f"token {self.github_access_token}" if self.github_access_token else None,
        "Content-Type": "application/json",
    }
    params = {'per_page': 100}
    latest_sha = None
    while url:
        response = requests.get(url, headers=headers, params=params)
        if response.status_code != 200:
            logger.warning(f'Failed to fetch commits for pull request: {response.json().get("message")}')
            return None
        commits = response.json()
        if commits:
            latest_sha = commits[-1].get('sha')
        # The "next" URL already carries per_page/page in its query string.
        url = response.links.get('next', {}).get('url')
        params = None
    return latest_sha
def add_line_comment_to_pull_request(self, repository_owner, repository_name, pull_request_number,
                                     commit_id, file_path, position, comment_body):
    """
    Post a review comment on one line of a pull request's diff.

    :param repository_owner: owner
    :param repository_name: repository name
    :param pull_request_number: pull request id
    :param commit_id: commit the comment refers to
    :param file_path: path of the file being commented on
    :param position: diff position sent as the ``position`` field
    :param comment_body: comment text
    :return:
        dict: the created comment as returned by GitHub (HTTP 201), otherwise None.
    """
    payload = {
        "commit_id": commit_id,
        "path": file_path,
        "position": position,
        "body": comment_body
    }
    request_headers = {
        "Authorization": f"token {self.github_access_token}",
        "Content-Type": "application/json",
        "Accept": "application/vnd.github.v3+json"
    }
    comments_url = (f'https://api.github.com/repos/{repository_owner}/{repository_name}'
                    f'/pulls/{pull_request_number}/comments')
    resp = requests.post(comments_url, headers=request_headers, json=payload)
    if resp.status_code != 201:
        logger.warning(f'Failed to add line comment: {resp.json()["message"]}')
        return None
    logger.info('Successfully added line comment to pull request.')
    return resp.json()
def get_pull_requests_created_in_last_x_seconds(self, repository_owner, repository_name, x_seconds):
    """
    Gets the pull requests created in the last x seconds.

    Args:
        repository_owner (str): Owner of the repository
        repository_name (str): Repository name
        x_seconds (int): The number of seconds in the past to look for PRs

    Returns:
        list: ``html_url`` strings of the pull requests created in the window
        (at most 100, one search page); an empty list if the search failed.
    """
    # Calculate the time x seconds ago
    time_x_seconds_ago = datetime.utcnow() - timedelta(seconds=x_seconds)
    # Convert to the ISO8601 format GitHub expects, remove milliseconds
    time_x_seconds_ago_str = time_x_seconds_ago.strftime('%Y-%m-%dT%H:%M:%SZ')
    query = f'repo:{repository_owner}/{repository_name} type:pr created:>{time_x_seconds_ago_str}'
    url = 'https://api.github.com/search/issues'
    headers = {
        "Authorization": f"token {self.github_access_token}",
        "Content-Type": "application/json",
    }
    # Let requests URL-encode the query (it contains spaces, ':' and '>') and ask
    # for the largest page size; the search API returns only 30 results by default.
    response = requests.get(url, headers=headers, params={'q': query, 'per_page': 100})
    if response.status_code == 200:
        return [pull_request['html_url'] for pull_request in response.json()['items']]
    logger.warning(f'Failed to fetch PRs: {response.json().get("message")}')
    return []
================================================
FILE: superagi/helper/google_calendar_creds.py
================================================
import pickle
import os
import json
import ast
from datetime import datetime
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from google.auth.transport.requests import Request
from superagi.config.config import get_config
from googleapiclient.discovery import build
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from sqlalchemy.orm import Session
from superagi.models.tool_config import ToolConfig
from superagi.resource_manager.file_manager import FileManager
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens
from superagi.helper.encyption_helper import decrypt_data, is_encrypted
class GoogleCalendarCreds:
    """Builds a Google Calendar API client from OAuth tokens stored for a toolkit."""

    def __init__(self, session: Session):
        self.session = session

    def get_credentials(self, toolkit_id):
        """
        Build a Google Calendar (v3) service for the given toolkit.

        Loads the OAuth tokens stored for the toolkit's organisation plus the
        GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET tool configs (decrypting them
        when encrypted). When the stored access token has expired, the
        credentials are refreshed and the new tokens are saved.

        Args:
            toolkit_id: Id of the Google Calendar toolkit.

        Returns:
            dict: ``{"success": True, "service": <calendar service>}``, or
            ``{"success": False}`` when no OAuth tokens are stored.
        """
        toolkit = self.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
        organisation_id = toolkit.organisation_id
        oauth_tokens = self.session.query(OauthTokens).filter(
            OauthTokens.toolkit_id == toolkit_id,
            OauthTokens.organisation_id == organisation_id).first()
        if not oauth_tokens:
            return {"success": False}
        user_id = oauth_tokens.user_id
        final_creds = json.loads(oauth_tokens.value)
        final_creds["refresh_token"] = self.fix_refresh_token(final_creds["refresh_token"])
        expire_time = self._parse_expiry(final_creds["expiry"])
        client_id = ""
        client_secret = ""
        tool_configs = self.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id).all()
        for tool_config in tool_configs:
            config = tool_config.__dict__
            if config["key"] == "GOOGLE_CLIENT_ID":
                client_id = self._plain_value(config["value"])
            if config["key"] == "GOOGLE_CLIENT_SECRET":
                client_secret = self._plain_value(config["value"])
        creds = Credentials.from_authorized_user_info(info={
            "client_id": client_id,
            "client_secret": client_secret,
            "refresh_token": final_creds["refresh_token"],
            "scopes": "https://www.googleapis.com/auth/calendar"
        })
        # Refresh (and persist) once the stored token has expired. The previous
        # check was inverted: it refreshed while the token was still valid and
        # never persisted a refresh after expiry.
        if expire_time <= datetime.utcnow():
            creds.refresh(Request())
            OauthTokens().add_or_update(self.session, toolkit_id, user_id, toolkit.organisation_id,
                                        "GOOGLE_CALENDAR_OAUTH_TOKENS", str(creds.to_json()))
        service = build('calendar', 'v3', credentials=creds)
        return {"success": True, "service": service}

    @staticmethod
    def _parse_expiry(expiry):
        """Parse a stored UTC expiry such as "2023-09-01T10:00:00.123456Z" (fraction optional)."""
        # isoformat()-style timestamps drop the fractional part when microseconds
        # are 0, which the single strict "%Y-%m-%dT%H:%M:%S.%fZ" format rejected.
        for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"):
            try:
                return datetime.strptime(expiry, fmt)
            except ValueError:
                continue
        raise ValueError(f"Unrecognised token expiry format: {expiry!r}")

    @staticmethod
    def _plain_value(value):
        """Return ``value`` decrypted if it is stored encrypted, unchanged otherwise."""
        return decrypt_data(value) if is_encrypted(value) else value

    def fix_refresh_token(self, refresh_token):
        """
        Insert a second "/" into a refresh token that contains exactly one.

        E.g. "1/abc" -> "1//abc"; tokens with zero or several "/" are returned as is.
        """
        # NOTE(review): presumably repairs tokens whose "//" got collapsed when stored — confirm.
        if refresh_token.count('/') == 1:
            slash_index = refresh_token.index('/')
            refresh_token = refresh_token[:slash_index + 1] + '/' + refresh_token[slash_index + 1:]
        return refresh_token
================================================
FILE: superagi/helper/google_search.py
================================================
import requests
import time
from pydantic import BaseModel
from superagi.lib.logger import logger
from superagi.helper.webpage_extractor import WebpageExtractor
class GoogleSearchWrap:
    def __init__(self, api_key, search_engine_id, num_results=3, num_pages=1, num_extracts=3):
        """
        Initialize the GoogleSearchWrap class.

        Args:
            api_key (str): Google API key
            search_engine_id (str): Google Search Engine ID
            num_results (int): Number of results per page
            num_pages (int): Number of pages to search
            num_extracts (int): Maximum number of result pages to fetch and extract text from
        """
        self.api_key = api_key
        self.search_engine_id = search_engine_id
        self.num_results = num_results
        self.num_pages = num_pages
        self.num_extracts = num_extracts
        self.extractor = WebpageExtractor()

    def search_run(self, query):
        """
        Run the Google Custom Search query over ``num_pages`` pages.

        Args:
            query (str): The query to search for.

        Returns:
            tuple: ``(snippets, links, status_code)`` where ``status_code`` is the
            HTTP status of the last request (None if no request was made).
        """
        all_snippets = []
        links = []
        status_code = None
        url = "https://www.googleapis.com/customsearch/v1"
        # One request per page; "start" is the 1-based index of the page's first
        # result. The old bound (num_pages * num_results) produced one page too
        # few when num_results == 1, leaving `response` unbound for num_pages == 1.
        for start in range(1, self.num_pages * self.num_results + 1, self.num_results):
            params = {
                "key": self.api_key,
                "cx": self.search_engine_id,
                "q": query,
                "num": self.num_results,
                "start": start
            }
            response = requests.get(url, params=params, timeout=100)
            status_code = response.status_code
            if response.status_code == 200:
                try:
                    json_data = response.json()
                    if "items" in json_data:
                        for item in json_data["items"]:
                            all_snippets.append(item["snippet"])
                            links.append(item["link"])
                    else:
                        logger.info("No items found in the response.")
                except ValueError as e:
                    logger.error(f"Error while parsing JSON data: {e}")
            else:
                logger.error(f"Error: {response.status_code}")
        return all_snippets, links, status_code

    def get_result(self, query):
        """
        Search and extract the text of the top result pages.

        Args:
            query (str): The query to search for.

        Returns:
            tuple: ``(snippets, webpages, links)``; ``webpages`` holds up to
            ``num_extracts`` page texts, each cut to roughly its first 500 words.
        """
        snippets, links, _ = self.search_run(query)
        webpages = []
        attempts = 0
        while snippets == [] and attempts < 2:
            attempts += 1
            logger.info("Google blocked the request. Trying again...")
            time.sleep(3)
            snippets, links, _ = self.search_run(query)
        if not links:
            return [], [], []
        # Slice instead of indexing links[i]: Google may return fewer links than
        # num_extracts, which previously raised IndexError.
        for link in links[:self.num_extracts]:
            time.sleep(3)
            content = self.extractor.extract_with_bs4(link)
            attempts = 0
            while content == "" and attempts < 2:
                attempts += 1
                content = self.extractor.extract_with_bs4(link)
            # Truncate after the retries; the limit used to be computed from the
            # (empty) first attempt, cutting any successful retry down to "".
            max_length = len(' '.join(content.split(" ")[:500]))
            webpages.append(content[:max_length])
        return snippets, webpages, links
================================================
FILE: superagi/helper/google_serp.py
================================================
import asyncio
from typing import Any, List
import aiohttp
from superagi.config.config import get_config
from superagi.helper.webpage_extractor import WebpageExtractor
class GoogleSerpApiWrap:
    def __init__(self, api_key, num_results=10, num_pages=1, num_extracts=3):
        """
        Initialize the GoogleSerpApiWrap class.

        Args:
            api_key (str): Serper (google.serper.dev) API key
            num_results (int): Maximum number of organic results to use
            num_pages (int): Number of pages to search
            num_extracts (int): Number of extracts to extract from each webpage
        """
        self.api_key = api_key
        self.num_results = num_results
        self.num_pages = num_pages
        self.num_extracts = num_extracts
        self.extractor = WebpageExtractor()

    def search_run(self, query):
        """
        Run a Google search through the Serper API.

        Args:
            query (str): The query to search for.

        Returns:
            dict: ``{"links": [...], "snippets": ...}`` as built by ``process_response``.
        """
        results = asyncio.run(self.fetch_serper_results(query=query))
        return self.process_response(results)

    async def fetch_serper_results(self,
                                   query: str, search_type: str = "search"
                                   ) -> dict[str, Any]:
        """
        Fetch the search results from the Serper API.

        Args:
            query (str): The query to search for.
            search_type (str): The type of search to perform (URL path segment).

        Returns:
            dict: The decoded JSON search results.
        """
        headers = {
            "X-API-KEY": self.api_key or "",
            "Content-Type": "application/json",
        }
        params = {"q": query, }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    f"https://google.serper.dev/{search_type}", headers=headers, params=params
            ) as response:
                response.raise_for_status()
                search_results = await response.json()
                return search_results

    def process_response(self, results) -> dict:
        """
        Turn a Serper response into snippets and links.

        Args:
            results (dict): The search results.

        Returns:
            dict: ``{"links": [...], "snippets": [...]}``. When nothing usable is
            found, ``snippets`` is the string "No good Google Search Result was
            found" and ``links`` is empty.
        """
        snippets: List[str] = []
        links: List[str] = []
        if results.get("answerBox"):
            answer_values = []
            answer_box = results.get("answerBox", {})
            if answer_box.get("answer"):
                answer_values.append(answer_box.get("answer"))
            elif answer_box.get("snippet"):
                answer_values.append(answer_box.get("snippet").replace("\n", " "))
            elif answer_box.get("snippetHighlighted"):
                answer_values.append(", ".join(answer_box.get("snippetHighlighted")))
            if len(answer_values) > 0:
                snippets.append("\n".join(answer_values))
        if results.get("knowledgeGraph"):
            knowledge_graph = results.get("knowledgeGraph", {})
            title = knowledge_graph.get("title")
            entity_type = knowledge_graph.get("type")
            if entity_type:
                snippets.append(f"{title}: {entity_type}.")
            description = knowledge_graph.get("description")
            if description:
                snippets.append(description)
            for attribute, value in knowledge_graph.get("attributes", {}).items():
                snippets.append(f"{title} {attribute}: {value}.")
        # "organic" can be absent (e.g. answer-box-only or empty responses);
        # indexing it directly raised KeyError.
        for result in (results.get("organic") or [])[:self.num_results]:
            if "snippet" in result:
                snippets.append(result["snippet"])
            if "link" in result and len(links) < self.num_results:
                links.append(result["link"])
            for attribute, value in result.get("attributes", {}).items():
                snippets.append(f"{attribute}: {value}.")
        if len(snippets) == 0:
            return {"snippets": "No good Google Search Result was found", "links": []}
        return {"links": links, "snippets": snippets}
================================================
FILE: superagi/helper/imap_email.py
================================================
import imaplib
class ImapEmail:
    def imap_open(self, imap_folder, email_sender, email_password, imap_server) -> imaplib.IMAP4_SSL:
        """
        Open an authenticated IMAP-over-SSL connection and select a folder.

        Args:
            imap_folder (str): The folder (mailbox) to select.
            email_sender (str): Login name (the sender's email address).
            email_password (str): The account password.
            imap_server (str): Host name of the IMAP server.

        Returns:
            imaplib.IMAP4_SSL: The connection with ``imap_folder`` selected.

        Raises:
            imaplib.IMAP4.error: e.g. when the server rejects the login.
        """
        conn = imaplib.IMAP4_SSL(imap_server)
        try:
            conn.login(email_sender, email_password)
            conn.select(imap_folder)
        except Exception:
            # Close the already-open SSL socket instead of leaking it, then
            # re-raise the original error for the caller.
            try:
                conn.shutdown()
            except OSError:
                pass
            raise
        return conn

    def adjust_imap_folder(self, imap_folder, email_sender) -> str:
        """
        Map generic "sent"/"draft" folder names to Gmail's folder names.

        Args:
            imap_folder (str): The requested folder.
            email_sender (str): The email address of the account.

        Returns:
            str: The quoted Gmail folder for Gmail accounts whose requested folder
            mentions "sent" or "draft"; otherwise ``imap_folder`` unchanged.
        """
        if "@gmail" in email_sender.lower():
            if "sent" in imap_folder.lower():
                return '"[Gmail]/Sent Mail"'
            if "draft" in imap_folder.lower():
                return '"[Gmail]/Drafts"'
        return imap_folder
================================================
FILE: superagi/helper/json_cleaner.py
================================================
import json
import re
from superagi.lib.logger import logger
import json5
class JsonCleaner:
    """Helpers that coax LLM output into parseable JSON-ish text."""

    @classmethod
    def clean_boolean(cls, input_str: str = ""):
        """
        Rewrite JSON ``false``/``true`` that follow a colon as Python ``False``/``True``.

        Args:
            input_str (str): Text to rewrite.

        Returns:
            str: The rewritten text (each match becomes ": False" / ": True").
        """
        replacements = (
            (r':\s*false', ': False'),
            (r':\s*true', ': True'),
        )
        for pattern, replacement in replacements:
            input_str = re.sub(pattern, replacement, input_str)
        return input_str

    @classmethod
    def extract_json_section(cls, input_str: str = ""):
        """
        Return the span from the first "{" to the last "}" that follows it.

        Args:
            input_str (str): Text that may contain a JSON object.

        Returns:
            str: The extracted span, or ``input_str`` unchanged if there is none.
        """
        start = input_str.find("{")
        end = input_str.rfind("}")
        if start == -1 or end < start:
            return input_str
        return input_str[start:end + 1]

    @classmethod
    def extract_json_array_section(cls, input_str: str = ""):
        """
        Return the span from the first "[" to the last "]" that follows it.

        Args:
            input_str (str): Text that may contain a JSON array.

        Returns:
            str: The extracted span, or ``input_str`` unchanged if there is none.
        """
        start = input_str.find("[")
        end = input_str.rfind("]")
        if start == -1 or end < start:
            return input_str
        return input_str[start:end + 1]

    @classmethod
    def remove_escape_sequences(cls, string):
        """
        Interpret backslash escapes (e.g. "\\n") in ``string``.

        Non-ASCII characters survive the round trip: their UTF-8 bytes are read
        back as Latin-1 by unicode_escape, re-encoded byte-for-byte by
        raw_unicode_escape, and decoded as UTF-8 again.

        Args:
            string (str): The text to unescape.

        Returns:
            str: The unescaped text.
        """
        unescaped = string.encode('utf-8').decode('unicode_escape')
        return unescaped.encode('raw_unicode_escape').decode('utf-8')

    @classmethod
    def balance_braces(cls, json_string: str) -> str:
        """
        Balance curly braces in ``json_string``.

        If there are more "}" than "{", all trailing "}" are stripped; afterwards
        any shortfall of "}" is appended at the end.

        Args:
            json_string (str): The json string to be processed.

        Returns:
            str: The json string with balanced braces.
        """
        if json_string.count('}') > json_string.count('{'):
            # rstrip drops every trailing brace at once, so one call is enough.
            json_string = json_string.rstrip("}")
        missing = json_string.count('{') - json_string.count('}')
        if missing > 0:
            json_string += '}' * missing
        return json_string
================================================
FILE: superagi/helper/llm_loader.py
================================================
from llama_cpp import Llama
from llama_cpp import LlamaGrammar
from superagi.config.config import get_config
from superagi.lib.logger import logger
class LLMLoader:
    """Process-wide singleton that lazily loads the local llama_cpp model and JSON grammar."""
    _instance = None
    _model = None
    _grammar = None

    def __new__(cls, *args, **kwargs):
        # Hand back the same instance on every construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, context_length):
        # __init__ still runs on every construction, so the latest
        # context_length wins; it only matters before the model is first loaded.
        self.context_length = context_length

    @property
    def model(self):
        """The Llama model from /app/local_model_path, created on first access (None if loading failed)."""
        if self._model is not None:
            return self._model
        try:
            gpu_layers = int(get_config('GPU_LAYERS', '-1'))
            self._model = Llama(model_path="/app/local_model_path",
                                n_ctx=self.context_length,
                                n_gpu_layers=gpu_layers)
        except Exception as e:
            logger.error(e)
        return self._model

    @property
    def grammar(self):
        """The JSON grammar from superagi/llms/grammar/json.gbnf, loaded on first access (None on failure)."""
        if self._grammar is not None:
            return self._grammar
        try:
            self._grammar = LlamaGrammar.from_file("superagi/llms/grammar/json.gbnf")
        except Exception as e:
            logger.error(e)
        return self._grammar
================================================
FILE: superagi/helper/models_helper.py
================================================
from superagi.llms.hugging_face import HuggingFace
class ModelsHelper:
    @staticmethod
    def validate_end_point(model_api_key, end_point, model_provider):
        """
        Check that a model endpoint is reachable (only Hugging Face is probed).

        Args:
            model_api_key (str): API key for the provider.
            end_point (str): Endpoint URL to verify.
            model_provider (str): Provider name; only 'Hugging Face' triggers a check.

        Returns:
            dict: ``{"success": True}`` for other providers,
            ``{"success": True, "result": <verify result>}`` on a successful check,
            or ``{"success": False, "error": <message>}`` if the check raised.
        """
        if model_provider != 'Hugging Face':
            return {"success": True}
        try:
            result = HuggingFace(api_key=model_api_key, end_point=end_point).verify_end_point()
        except Exception as e:
            return {"success": False, "error": str(e)}
        return {"success": True, "result": result}
================================================
FILE: superagi/helper/prompt_reader.py
================================================
from pathlib import Path
class PromptReader:
    """Loads prompt templates from a ``prompts/`` folder next to a given module file."""

    @staticmethod
    def _read_prompt(current_file: str, prompt_file: str) -> str:
        """
        Read ``<directory of current_file>/prompts/<prompt_file>``.

        Raises:
            FileNotFoundError: If the prompt file does not exist (the error is
                printed before being re-raised).
        """
        file_path = str(Path(current_file).resolve().parent) + "/prompts/" + prompt_file
        try:
            # ``with`` closes the handle even if read() fails, which the previous
            # open()/read()/close() sequence did not.
            with open(file_path, "r") as f:
                return f.read()
        except FileNotFoundError as e:
            print(e.__str__())
            raise

    @staticmethod
    def read_tools_prompt(current_file: str, prompt_file: str) -> str:
        """Read a tool prompt template located in ``prompts/`` beside ``current_file``."""
        return PromptReader._read_prompt(current_file, prompt_file)

    @staticmethod
    def read_agent_prompt(current_file: str, prompt_file: str) -> str:
        """Read an agent prompt template located in ``prompts/`` beside ``current_file``."""
        return PromptReader._read_prompt(current_file, prompt_file)
================================================
FILE: superagi/helper/read_email.py
================================================
import os
import re
from email.header import decode_header
from bs4 import BeautifulSoup
class ReadEmail:
    def clean_email_body(self, email_body):
        """
        Turn an (HTML) email body into a single line of ASCII text without URLs.

        Args:
            email_body (str | None): The email body to be cleaned.

        Returns:
            str: Text content with line breaks removed, whitespace collapsed,
            non-ASCII characters dropped and http(s) URLs stripped.
        """
        if email_body is None: email_body = ""
        email_body = BeautifulSoup(email_body, "html.parser")
        email_body = email_body.get_text()
        email_body = "".join(email_body.splitlines())
        email_body = " ".join(email_body.split())
        email_body = email_body.encode("ascii", "ignore")
        email_body = email_body.decode("utf-8", "ignore")
        email_body = re.sub(r"http\S+", "", email_body)
        return email_body

    def clean(self, text):
        """
        Replace every non-alphanumeric character with "_".

        Args:
            text (str): The text to be cleaned.

        Returns:
            str: The cleaned text.
        """
        return "".join(c if c.isalnum() else "_" for c in text)

    def obtain_header(self, msg):
        """
        Extract the main headers of an email.

        Only the first encoded-word chunk of the Subject is decoded.

        Args:
            msg (email.message.Message): The email message.

        Returns:
            tuple: ``(From, To, Date, Subject)``; Subject is always text ("" when
            missing).
        """
        if msg["Subject"] is not None:
            subject, encoding = decode_header(msg["Subject"])[0]
        else:
            subject, encoding = "", ""
        if isinstance(subject, bytes):
            if encoding is None:
                subject = ""
            else:
                try:
                    subject = subject.decode(encoding)
                except (LookupError, UnicodeDecodeError):
                    # The old `except[LookupError]` named a *list*, which Python
                    # rejects with TypeError as soon as an unknown charset shows
                    # up. Fall back to a lossy UTF-8 decode so callers get text.
                    subject = subject.decode("utf-8", errors="replace")
        return msg["From"], msg["To"], msg["Date"], subject

    def download_attachment(self, part, subject):
        """
        Save an attachment into a folder named after the (cleaned) subject.

        Args:
            part (email.message.Message): The MIME part holding the attachment.
            subject (str): The subject of the email.

        Returns:
            None. Parts without a usable filename or payload are skipped.
        """
        filename = part.get_filename()
        if not filename:
            return
        # The filename comes from the (untrusted) email; keep only its last path
        # component so names like "../../x" cannot escape the target folder.
        safe_name = os.path.basename(filename.replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            return
        payload = part.get_payload(decode=True)
        if payload is None:
            return
        folder_name = self.clean(subject)
        # An empty subject yields "", where os.mkdir("") used to raise; the file
        # then goes into the current directory instead.
        if folder_name:
            os.makedirs(folder_name, exist_ok=True)
        filepath = os.path.join(folder_name, safe_name)
        with open(filepath, "wb") as attachment_file:
            attachment_file.write(payload)
================================================
FILE: superagi/helper/resource_helper.py
================================================
import os
from superagi.config.config import get_config
from superagi.helper.s3_helper import S3Helper
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.resource import Resource
from superagi.types.storage_types import StorageType
class ResourceHelper:
@classmethod
def make_written_file_resource(cls, file_name: str, agent: Agent, agent_execution: AgentExecution, session):
    """
    Function to create (or refresh) the Resource row for a file an agent has written.

    The file must already exist on the local filesystem: its size is read with
    os.path.getsize. If an OUTPUT resource with the same name, path, storage
    type, file type, agent and agent execution already exists, only its size is
    updated; otherwise a new Resource is added. Both paths commit the session.

    Args:
        file_name (str): The name of the file.
        agent (Agent): Agent related to resource.
        agent_execution(AgentExecution): Agent Execution related to a resource
        session (Session): The database session.
    Returns:
        Resource: The existing (updated) or newly created Resource object.
    """
    storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value))
    file_parts = os.path.splitext(file_name)
    # NOTE(review): os.path.splitext always returns a 2-tuple, so this condition is
    # never true and ".txt" is never appended. Appending it would also make
    # getsize() below look for a file that was not written — confirm intent.
    if len(file_parts) <= 1:
        file_name = file_name + ".txt"
    file_extension = os.path.splitext(file_name)[1][1:]
    # Map the extension (without the dot) to the type stored on the Resource.
    if file_extension in ["png", "jpg", "jpeg"]:
        file_type = "image/" + file_extension
    elif file_extension == "txt":
        file_type = "application/txt"
    else:
        file_type = "application/misc"
    if agent is not None:
        final_path = ResourceHelper.get_agent_write_resource_path(file_name, agent, agent_execution)
    else:
        final_path = ResourceHelper.get_resource_path(file_name)
    file_size = os.path.getsize(final_path)
    # NOTE(review): this path is computed even when agent is None, and agent.id /
    # agent_execution.id are dereferenced below in both branches, so agent=None
    # appears to fail with AttributeError — verify against callers.
    file_path = ResourceHelper.get_agent_write_resource_path(file_name, agent, agent_execution)
    logger.info("make_written_file_resource:", final_path)
    if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
        # For S3 storage the stored path is prefixed with "resources".
        file_path = "resources" + file_path
    # Look for an existing OUTPUT resource for the same file of this execution.
    existing_resource = session.query(Resource).filter_by(
        name=file_name,
        path=file_path,
        storage_type=storage_type.value,
        type=file_type,
        channel="OUTPUT",
        agent_id=agent.id,
        agent_execution_id=agent_execution.id
    ).first()
    if existing_resource:
        # Update the existing resource attributes
        existing_resource.size = file_size
        session.commit()
        session.flush()
        return existing_resource
    else:
        resource = Resource(
            name=file_name,
            path=file_path,
            storage_type=storage_type.value,
            size=file_size,
            type=file_type,
            channel="OUTPUT",
            agent_id=agent.id,
            agent_execution_id=agent_execution.id
        )
        session.add(resource)
        session.commit()
        return resource
@classmethod
def get_formatted_agent_level_path(cls, agent: Agent, path) -> object:
formatted_agent_name = agent.name.replace(" ", "")
return path.replace("{agent_id}", formatted_agent_name + '_' + str(agent.id))
@classmethod
def get_formatted_agent_execution_level_path(cls, agent_execution: AgentExecution, path):
formatted_agent_execution_name = agent_execution.name.replace(" ", "")
return path.replace("{agent_execution_id}", (formatted_agent_execution_name + '_' + str(agent_execution.id)))
@classmethod
def get_resource_path(cls, file_name: str):
"""Get final path of the resource.
Args:
file_name (str): The name of the file.
"""
return ResourceHelper.get_root_output_dir() + file_name
@classmethod
def get_root_output_dir(cls):
"""Get root dir of the resource.
"""
root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
else:
root_dir = os.getcwd() + "/"
return root_dir
@classmethod
def get_root_input_dir(cls):
"""Get root dir of the resource.
"""
root_dir = get_config('RESOURCES_INPUT_ROOT_DIR')
if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
else:
root_dir = os.getcwd() + "/"
return root_dir
@classmethod
def get_agent_write_resource_path(cls, file_name: str, agent: Agent, agent_execution: AgentExecution):
"""Get agent resource path to write files
Args:
file_name (str): The name of the file.
agent (Agent): The unique identifier of the agent.
agent_execution (AgentExecution): The unique identifier of the agent.
"""
root_dir = ResourceHelper.get_root_output_dir()
if agent is not None and "{agent_id}" in root_dir:
root_dir = ResourceHelper.get_formatted_agent_level_path(agent, root_dir)
if agent_execution is not None and "{agent_execution_id}" in root_dir:
root_dir = ResourceHelper.get_formatted_agent_execution_level_path(agent_execution, root_dir)
directory = os.path.dirname(root_dir)
os.makedirs(directory, exist_ok=True)
final_path = root_dir + file_name
return final_path
@staticmethod
def __check_file_path_exists(path):
return (StorageType.get_storage_type(get_config("STORAGE_TYPE",
StorageType.FILE.value)) is StorageType.S3 and
not S3Helper().check_file_exists_in_s3(path)) or (
StorageType.get_storage_type(
get_config("STORAGE_TYPE", StorageType.FILE.value)) is StorageType.FILE
and not os.path.exists(path))
@classmethod
def get_agent_read_resource_path(cls, file_name, agent: Agent, agent_execution: AgentExecution):
"""Get agent resource path to read files i.e. both input and output directory
at agent level.
Args:
file_name (str): The name of the file.
agent (Agent): The agent corresponding to resource.
agent_execution (AgentExecution): The agent execution corresponding to the resource.
"""
final_path = ResourceHelper.get_root_input_dir() + file_name
if "{agent_id}" in final_path:
final_path = ResourceHelper.get_formatted_agent_level_path(
agent=agent,
path=final_path)
output_root_dir = ResourceHelper.get_root_output_dir()
if final_path is None or cls.__check_file_path_exists(final_path):
if output_root_dir is not None:
final_path = ResourceHelper.get_root_output_dir() + file_name
if "{agent_id}" in final_path:
final_path = ResourceHelper.get_formatted_agent_level_path(
agent=agent,
path=final_path)
if "{agent_execution_id}" in final_path:
final_path = ResourceHelper.get_formatted_agent_execution_level_path(
agent_execution=agent_execution,
path=final_path)
return final_path
================================================
FILE: superagi/helper/s3_helper.py
================================================
import json
import boto3
from fastapi import HTTPException
from superagi.config.config import get_config
from superagi.lib.logger import logger
from urllib.parse import unquote
import json
class S3Helper:
    """Thin wrapper around a boto3 S3 client for SuperAGI resource files."""

    def __init__(self, bucket_name=None):
        """
        Initialize the S3Helper class.

        Using the AWS credentials from the configuration file, create a boto3 client.

        Args:
            bucket_name (str, optional): Bucket used by the upload/delete/JSON helpers. Defaults
                to the ``BUCKET_NAME`` config value read at construction time (previously the
                default was evaluated once, at import time).
        """
        self.s3 = S3Helper.__get_s3_client()
        self.bucket_name = bucket_name if bucket_name is not None else get_config("BUCKET_NAME")

    @classmethod
    def __get_s3_client(cls):
        """
        Get an S3 client.

        Returns:
            A boto3 S3 client built from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY config values.
        """
        return boto3.client(
            's3',
            aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
        )

    def upload_file(self, file, path):
        """
        Upload a file to S3.

        Args:
            file (FileStorage): The file to upload.
            path (str): The path to upload the file to.

        Raises:
            HTTPException: If the upload fails (the original error is chained as the cause).

        Returns:
            None
        """
        try:
            self.s3.upload_fileobj(file, self.bucket_name, path)
            logger.info("File uploaded to S3 successfully!")
        except Exception as e:
            logger.error(f"Error uploading file to S3: {e}")
            raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") from e

    def check_file_exists_in_s3(self, file_path):
        """Return True if any object key under "resources" + file_path exists in BUCKET_NAME.

        NOTE(review): this is a prefix match, so "a.txt" also matches "a.txt.bak".
        """
        response = self.s3.list_objects_v2(Bucket=get_config("BUCKET_NAME"), Prefix="resources" + file_path)
        return 'Contents' in response

    def read_from_s3(self, file_path):
        """Read "resources" + file_path from BUCKET_NAME and decode it as UTF-8 text."""
        file_path = "resources" + file_path
        logger.info(f"Reading file from s3: {file_path}")
        response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=file_path)
        if response['ResponseMetadata']['HTTPStatusCode'] == 200:
            return response['Body'].read().decode('utf-8')
        raise Exception(f"Error read_from_s3: {response}")

    def read_binary_from_s3(self, file_path):
        """Read "resources" + file_path from BUCKET_NAME and return the raw bytes."""
        file_path = "resources" + file_path
        logger.info(f"Reading file from s3: {file_path}")
        response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=file_path)
        if response['ResponseMetadata']['HTTPStatusCode'] == 200:
            return response['Body'].read()
        raise Exception(f"Error read_from_s3: {response}")

    def get_json_file(self, path):
        """
        Get a JSON file from S3.

        Args:
            path (str): The path to the JSON file.

        Raises:
            HTTPException: If the object cannot be fetched or parsed (cause is chained).

        Returns:
            dict: The JSON file.
        """
        try:
            obj = self.s3.get_object(Bucket=self.bucket_name, Key=path)
            s3_response = obj['Body'].read().decode('utf-8')
            return json.loads(s3_response)
        except Exception as e:
            logger.error(f"Error reading JSON file from S3: {e}")
            raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") from e

    def delete_file(self, path):
        """
        Delete a file from S3.

        Args:
            path (str): The path (without the "resources" prefix) of the file to delete.

        Raises:
            HTTPException: If the deletion fails (cause is chained).

        Returns:
            None
        """
        try:
            path = "resources" + path
            self.s3.delete_object(Bucket=self.bucket_name, Key=path)
            logger.info("File deleted from S3 successfully!")
        except Exception as e:
            logger.error(f"Error deleting file from S3: {e}")
            raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") from e

    def upload_file_content(self, content, file_path):
        """Write ``content`` to ``file_path`` (used verbatim as the key) in the helper's bucket."""
        try:
            self.s3.put_object(Bucket=self.bucket_name, Key=file_path, Body=content)
        except Exception as e:
            logger.error(f"Error uploading file content to S3: {e}")
            raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") from e

    def get_download_url_of_resources(self, db_resources_arr):
        """Copy each resource into INSTAGRAM_TOOL_BUCKET_NAME and return public URLs per execution.

        Returns:
            dict: agent_execution_id -> list of public object URLs.
        """
        response_obj = {}
        for db_resource in db_resources_arr:
            response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path)
            content = response["Body"].read()
            bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME")
            file_name = db_resource.path.split('/')[-1]
            file_name = ''.join(char for char in file_name if char != "`")
            object_key = f"public_resources/run_id{db_resource.agent_execution_id}/{file_name}"
            # Same credentials as the old ad-hoc client, so the existing client is reused.
            self.s3.put_object(Bucket=bucket_name, Key=object_key, Body=content)
            file_url = f"https://{bucket_name}.s3.amazonaws.com/{object_key}"
            response_obj.setdefault(db_resource.agent_execution_id, []).append(file_url)
        return response_obj

    def list_files_from_s3(self, file_path):
        """List object keys under "resources" + file_path in BUCKET_NAME.

        Raises:
            Exception: "Error listing files from s3" when the call fails or nothing matches;
                the underlying error is chained as ``__cause__``.
        """
        try:
            file_path = "resources" + file_path
            logger.info(f"Listing files from s3 with prefix: {file_path}")
            response = self.s3.list_objects_v2(Bucket=get_config("BUCKET_NAME"), Prefix=file_path)
            if 'Contents' not in response:
                raise Exception("No contents in S3 response")
            logger.info(response['Contents'])
            return [obj['Key'] for obj in response['Contents']]
        except Exception as e:
            raise Exception("Error listing files from s3") from e
================================================
FILE: superagi/helper/time_helper.py
================================================
from datetime import datetime
def _parse_timestamp(value):
    """Coerce a datetime, or a "%Y-%m-%d %H:%M:%S[.%f]" string, into a datetime."""
    if isinstance(value, datetime):
        return value
    text = str(value)
    try:
        return datetime.strptime(text, "%Y-%m-%d %H:%M:%S.%f")
    except ValueError:
        # str(datetime) omits the fractional part when microsecond == 0.
        return datetime.strptime(text, "%Y-%m-%d %H:%M:%S")


def get_time_difference(timestamp1, timestamp2):
    """
    Break the time from ``timestamp1`` to ``timestamp2`` into years, months, days, hours, minutes.

    Months are approximated as 30 days and years as 365 days; leftover seconds are dropped.

    Args:
        timestamp1: datetime or "%Y-%m-%d %H:%M:%S[.%f]" string (start).
        timestamp2: datetime or "%Y-%m-%d %H:%M:%S[.%f]" string (end).

    Returns:
        dict: Keys "years", "months", "days", "hours", "minutes" with integer values.
    """
    parsed_timestamp1 = _parse_timestamp(timestamp1)
    parsed_timestamp2 = _parse_timestamp(timestamp2)
    total_seconds = int((parsed_timestamp2 - parsed_timestamp1).total_seconds())
    years, seconds_remainder = divmod(total_seconds, 365 * 24 * 60 * 60)
    months, seconds_remainder = divmod(seconds_remainder, 30 * 24 * 60 * 60)
    days, seconds_remainder = divmod(seconds_remainder, 24 * 60 * 60)
    hours, seconds_remainder = divmod(seconds_remainder, 60 * 60)
    minutes, _ = divmod(seconds_remainder, 60)
    return {
        "years": years,
        "months": months,
        "days": days,
        "hours": hours,
        "minutes": minutes
    }
def parse_interval_to_seconds(interval: str) -> int:
    """Convert an "<amount> <Unit>" string such as "5 Minutes" to seconds (a month is 30 days)."""
    seconds_per_unit = {
        "Minutes": 60,
        "Hours": 3600,
        "Days": 86400,
        "Weeks": 604800,
        "Months": 2592000,
    }
    amount, unit_name = interval.split()
    return int(amount) * seconds_per_unit[unit_name]
================================================
FILE: superagi/helper/token_counter.py
================================================
from typing import List
import tiktoken
from superagi.types.common import BaseMessage
from superagi.lib.logger import logger
from superagi.models.models import Models
from sqlalchemy.orm import Session
class TokenCounter:
    """Token-limit lookup and tiktoken-based token counting for LLM messages."""

    def __init__(self, session: Session = None, organisation_id: int = None):
        """
        Args:
            session (Session): Database session used to look up per-model token limits.
            organisation_id (int): Organisation whose model token table is queried.
        """
        self.session = session
        self.organisation_id = organisation_id

    def token_limit(self, model: str = "gpt-3.5-turbo-0301") -> int:
        """
        Function to return the token limit for a given model.

        Args:
            model (str): The model to return the token limit for.

        Returns:
            int: The token limit, or 8092 when the model is not in the organisation's table.
        """
        try:
            model_token_limit_dict = Models.fetch_model_tokens(self.session, self.organisation_id)
            return model_token_limit_dict[model]
        except KeyError:
            logger.warning(f"Warning: model {model} not found. Using default token limit of 8092.")
            return 8092

    @staticmethod
    def count_message_tokens(messages: List[BaseMessage], model: str = "gpt-3.5-turbo-0301") -> int:
        """
        Function to count the number of tokens in a list of messages.

        Plain strings are treated as ``{'content': <string>}``; other messages are indexed
        with ``['content']``. Unknown models fall back to the cl100k_base encoding and
        4 tokens of per-message overhead.

        Args:
            messages (List[BaseMessage]): The list of messages to count the tokens for.
            model (str): The model to count the tokens for.

        Returns:
            int: The number of tokens in the messages.
        """
        default_tokens_per_message = 4
        model_token_per_message_dict = {"gpt-3.5-turbo-0301": 4, "gpt-4-0314": 3, "gpt-3.5-turbo": 4, "gpt-4": 3,
                                        "gpt-3.5-turbo-16k": 4, "gpt-4-32k": 3, "gpt-4-32k-0314": 3,
                                        "models/chat-bison-001": 4}
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")
        tokens_per_message = model_token_per_message_dict.get(model, default_tokens_per_message)

        num_tokens = 0
        for message in messages:
            if isinstance(message, str):
                message = {'content': message}
            num_tokens += tokens_per_message
            num_tokens += len(encoding.encode(message['content']))
        # Fixed overhead added once for the reply priming.
        num_tokens += 3
        return num_tokens

    @staticmethod
    def count_text_tokens(message: str) -> int:
        """
        Function to count the number of tokens in a text.

        Args:
            message (str): The text to count the tokens for.

        Returns:
            int: The cl100k_base token count of the text plus 4.
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(message)) + 4
        return num_tokens
================================================
FILE: superagi/helper/tool_helper.py
================================================
import importlib.util
import inspect
import json
import os
import sys
import zipfile
from urllib.parse import urlparse
import requests
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.models.tool import Tool
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from superagi.tools.base_tool import BaseTool, ToolConfiguration
from superagi.tools.base_tool import BaseToolkit
def parse_github_url(github_url):
    """Reduce "https://github.com/<owner>/<repo>[/...]" to "<owner>/<repo>/main" (branch is fixed)."""
    segments = github_url.split('/')
    return "/".join((segments[3], segments[4], "main"))
def download_tool(tool_url, target_folder):
    """
    Download a GitHub repository archive and extract it into ``target_folder``.

    Only entries nested inside a top-level folder of the repository are extracted (files that
    sit directly in the repository root are skipped), and the GitHub-generated wrapper folder
    ("<owner>-<repo>-<sha>/") is stripped from every path.

    Args:
        tool_url (str): GitHub repository URL, e.g. https://github.com/<owner>/<repo>.
        target_folder (str): Existing directory to extract into.

    Raises:
        requests.HTTPError: If GitHub does not return the archive successfully.
        ValueError: If an archive entry would be written outside ``target_folder``.
    """
    parsed_url = parse_github_url(tool_url)
    parts = parsed_url.split("/")
    owner, repo, branch = parts[0], parts[1], parts[2]
    archive_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/{branch}"
    response = requests.get(archive_url)
    # Fail loudly instead of writing an error page to disk and tripping over BadZipFile.
    response.raise_for_status()

    tool_zip_file_path = os.path.join(target_folder, 'tool.zip')
    with open(tool_zip_file_path, 'wb') as f:
        f.write(response.content)

    target_root = os.path.realpath(target_folder)
    archive_folder = f"{owner}-{repo}"
    try:
        logger.info("Reading Zip")
        with zipfile.ZipFile(tool_zip_file_path, 'r') as z:
            # Keep only entries nested below a top-level folder of the repository.
            members = [m for m in z.namelist() if m.startswith(archive_folder) and m.count('/') > 1]
            for member in members:
                target_name = member.replace(f"{archive_folder}/", "", 1)
                # Skip the unique hash folder while extracting.
                segments = target_name.split('/', 1)
                if len(segments) < 2 or not segments[1]:
                    continue
                target_name = segments[1]
                target_path = os.path.join(target_folder, target_name)
                # Reject "zip slip" entries ("../x", absolute paths) that escape target_folder.
                resolved = os.path.realpath(target_path)
                if os.path.commonpath([target_root, resolved]) != target_root:
                    raise ValueError(f"Refusing to extract {member!r} outside {target_folder!r}")
                if member.endswith('/'):
                    os.makedirs(target_path, exist_ok=True)
                else:
                    # Archives are not guaranteed to list every directory entry.
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)
                    with open(target_path, 'wb') as outfile, z.open(member) as infile:
                        outfile.write(infile.read())
        logger.info("Download Success!")
    finally:
        os.remove(tool_zip_file_path)
def get_classes_in_file(file_path, clazz):
    """
    Load a Python file and describe every subclass of ``clazz`` visible in its namespace.

    Each matching class is instantiated with no arguments. Classes that cannot be instantiated,
    or whose metadata cannot be read, are skipped with a warning instead of failing the scan.

    Args:
        file_path (str): Path of the Python file to load.
        clazz (type): BaseToolkit or BaseTool.

    Returns:
        list[dict]: One dict per class with "class_name" plus the tool/toolkit metadata.
    """
    classes = []
    module = load_module_from_file(file_path)
    for _, member in inspect.getmembers(module, inspect.isclass):
        if not issubclass(member, clazz) or member == clazz:
            continue
        class_dict = {'class_name': member.__name__}
        try:
            # Use the class object directly: it may be bound under an alias, in which case
            # getattr(module, member.__name__) would not find it.
            obj = member()
            if clazz == BaseToolkit:
                get_toolkit_info(class_dict, classes, obj)
            elif clazz == BaseTool:
                get_tool_info(class_dict, classes, obj)
        except Exception as e:
            logger.warning(f"Skipping class {member.__name__} in {file_path}: {e}")
    return classes
def get_tool_info(class_dict, classes, obj):
    """
    Record a tool instance's name and description in ``class_dict`` and append it to ``classes``.
    """
    for target_key, attribute in (("tool_name", "name"), ("tool_description", "description")):
        class_dict[target_key] = getattr(obj, attribute)
    classes.append(class_dict)
def get_toolkit_info(class_dict, classes, obj):
    """
    Record a toolkit instance's name, description, tools and env keys in ``class_dict``,
    then append it to ``classes``.
    """
    extractors = (
        ("toolkit_name", lambda: obj.name),
        ("toolkit_description", lambda: obj.description),
        ("toolkit_tools", lambda: obj.get_tools()),
        ("toolkit_keys", lambda: obj.get_env_keys()),
    )
    for target_key, extract in extractors:
        class_dict[target_key] = extract()
    classes.append(class_dict)
def load_module_from_file(file_path):
    """
    Execute the Python source at ``file_path`` and return the resulting module object.

    The module is always named "module_name" and is not registered in ``sys.modules``.
    """
    util = importlib.util
    module_spec = util.spec_from_file_location("module_name", file_path)
    loaded_module = util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return loaded_module
def init_tools(folder_paths, session, tool_name_to_toolkit):
    """
    Register every BaseTool subclass found under the given tool roots.

    Each immediate subfolder of every existing root is put on ``sys.path`` (once) so tool
    modules can import their siblings. Every non-``__init__`` ``.py`` file in it is then scanned.

    Args:
        folder_paths (list[str]): Tool root directories; missing ones are skipped.
        session: Database session passed through to the tool registration.
        tool_name_to_toolkit (dict): Maps (tool_name, folder_name) to a toolkit id.
    """
    for folder_path in folder_paths:
        if not os.path.exists(folder_path):
            continue
        for folder_name in os.listdir(folder_path):
            folder_dir = os.path.join(folder_path, folder_name)
            if not os.path.isdir(folder_dir):
                continue
            # Avoid growing sys.path with a duplicate entry on every registration pass.
            if folder_dir not in sys.path:
                sys.path.append(folder_dir)
            for file_name in os.listdir(folder_dir):
                if not file_name.endswith(".py") or file_name.startswith("__init__"):
                    continue
                file_path = os.path.join(folder_dir, file_name)
                classes = get_classes_in_file(file_path=file_path, clazz=BaseTool)
                update_base_tool_class_info(classes, file_name, folder_name, session, tool_name_to_toolkit)
def update_base_tool_class_info(classes, file_name, folder_name, session, tool_name_to_toolkit):
    """
    Persist the tool classes found in one file, but only for tools whose
    (tool_name, folder_name) pair already maps to a toolkit id.
    """
    for class_info in classes:
        if class_info["class_name"] is None:
            continue
        tool_name = class_info['tool_name']
        tool_description = class_info['tool_description']
        toolkit_id = tool_name_to_toolkit.get((tool_name, folder_name), None)
        if toolkit_id is None:
            continue
        Tool.add_or_update(session, tool_name=tool_name, folder_name=folder_name,
                           class_name=class_info['class_name'], file_name=file_name,
                           toolkit_id=toolkit_id, description=tool_description)
def init_toolkits(code_link, existing_toolkits, folder_paths, organisation, session):
    """
    Register every BaseToolkit subclass found under the given tool roots and prune stale ones.

    Args:
        code_link (str | None): Source link stored on each registered toolkit.
        existing_toolkits (list[Toolkit]): Toolkits already stored for the organisation;
            those not rediscovered here are deleted.
        folder_paths (list[str]): Tool root directories; missing ones are skipped.
        organisation: Organisation that owns the toolkits.
        session: Database session.

    Returns:
        dict: Maps (tool_name, folder_name) to the owning toolkit id.
    """
    tool_name_to_toolkit = {}
    new_toolkits = []
    for folder_path in folder_paths:
        if not os.path.exists(folder_path):
            continue
        for folder_name in os.listdir(folder_path):
            folder_dir = os.path.join(folder_path, folder_name)
            if not os.path.isdir(folder_dir):
                continue
            # Avoid growing sys.path with a duplicate entry on every registration pass.
            if folder_dir not in sys.path:
                sys.path.append(folder_dir)
            for file_name in os.listdir(folder_dir):
                if not file_name.endswith(".py") or file_name.startswith("__init__"):
                    continue
                file_path = os.path.join(folder_dir, file_name)
                classes = get_classes_in_file(file_path=file_path, clazz=BaseToolkit)
                tool_name_to_toolkit = update_base_toolkit_info(classes, code_link, folder_name, new_toolkits,
                                                                organisation, session, tool_name_to_toolkit)
    # Delete toolkits that are not present in the updated toolkits
    delete_extra_toolkit(existing_toolkits, new_toolkits, session)
    return tool_name_to_toolkit
def delete_extra_toolkit(existing_toolkits, new_toolkits, session):
    """
    Delete stored toolkits (with their tools and configs) that were not rediscovered.

    Args:
        existing_toolkits (list[Toolkit]): Toolkits currently stored for the organisation.
        new_toolkits (list[Toolkit]): Toolkits registered in this pass, matched by name.
        session: Database session; committed once at the end.
    """
    # Build the name set once instead of rebuilding a list for every existing toolkit.
    new_toolkit_names = {new_toolkit.name for new_toolkit in new_toolkits}
    for toolkit in existing_toolkits:
        if toolkit.name in new_toolkit_names:
            continue
        session.query(Tool).filter(Tool.toolkit_id == toolkit.id).delete()
        session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit.id).delete()
        session.delete(toolkit)
    # Commit the changes to the database
    session.commit()
def update_base_toolkit_info(classes, code_link, folder_name, new_toolkits, organisation, session,
                             tool_name_to_toolkit):
    """
    Persist the toolkits found in one file, together with their tools and config keys.

    Each stored toolkit is appended to ``new_toolkits``. The returned mapping extends
    ``tool_name_to_toolkit`` with (tool_name, folder_name) -> toolkit id. Entries that were
    already present take precedence over newly discovered ones.
    """
    for class_info in classes:
        if class_info["class_name"] is None:
            continue
        toolkit_name = class_info["toolkit_name"]
        toolkit_description = class_info["toolkit_description"]
        toolkit_tools = class_info["toolkit_tools"]
        config_keys = class_info["toolkit_keys"]

        toolkit = Toolkit.add_or_update(
            session,
            name=toolkit_name,
            description=toolkit_description,
            show_toolkit=len(toolkit_tools) > 1,
            organisation_id=organisation.id,
            tool_code_link=code_link
        )
        new_toolkits.append(toolkit)

        discovered = {}
        for tool in toolkit_tools:
            Tool.add_or_update(session, tool_name=tool.name, folder_name=folder_name,
                               class_name=None, file_name=None,
                               toolkit_id=toolkit.id, description=tool.description)
            discovered[tool.name, folder_name] = toolkit.id
        tool_name_to_toolkit = {**discovered, **tool_name_to_toolkit}

        for config_key in config_keys:
            if not isinstance(config_key, ToolConfiguration):
                ToolConfig.add_or_update(session, toolkit_id=toolkit.id, key=config_key)
                continue
            ToolConfig.add_or_update(session, toolkit_id=toolkit.id,
                                     key=config_key.key,
                                     key_type=config_key.key_type,
                                     is_required=config_key.is_required,
                                     is_secret=config_key.is_secret)
    return tool_name_to_toolkit
def process_files(folder_paths, session, organisation, code_link=None):
    """
    Sync the toolkits and tools found under ``folder_paths`` into the database for ``organisation``.
    """
    stored_toolkits = (session.query(Toolkit)
                       .filter(Toolkit.organisation_id == organisation.id)
                       .all())
    toolkit_lookup = init_toolkits(code_link, stored_toolkits, folder_paths, organisation, session)
    init_tools(folder_paths, session, toolkit_lookup)
def get_readme_content_from_code_link(tool_code_link):
    """
    Fetch the README of the GitHub repository behind ``tool_code_link``.

    Tries README.MD first and falls back to README.md on a 404. The branch comes from a
    ".../tree/<branch>" style URL, otherwise "main".

    Returns:
        str | None: Body of the last response (an error page if neither README exists),
        or None when no link is given.
    """
    if tool_code_link is None:
        return None
    segments = urlparse(tool_code_link).path.split("/")
    username = segments[1]
    repository = segments[2]
    branch = segments[4] if len(segments) > 4 else "main"
    base_url = f"https://raw.githubusercontent.com/{username}/{repository}/{branch}"
    response = requests.get(f"{base_url}/README.MD")
    if response.status_code == 404:
        response = requests.get(f"{base_url}/README.md")
    return response.text
def register_toolkits(session, organisation):
    """Register the built-in and external toolkits for ``organisation`` (no-op when it is None)."""
    if organisation is None:
        return
    # Marketplace tools are not included here; register_marketplace_toolkits adds them.
    # if get_config("ENV", "DEV") == "PROD":
    #     tool_paths.append("superagi/tools/marketplace_tools")
    tool_paths = ["superagi/tools", "superagi/tools/external_tools"]
    process_files(tool_paths, session, organisation)
    logger.info(f"Toolkits Registered Successfully for Organisation ID : {organisation.id}!")
def register_marketplace_toolkits(session, organisation):
    """Register built-in, external and marketplace toolkits for ``organisation`` (no-op when None)."""
    if organisation is None:
        return
    tool_paths = [
        "superagi/tools",
        "superagi/tools/external_tools",
        "superagi/tools/marketplace_tools",
    ]
    process_files(tool_paths, session, organisation)
    logger.info(f"Marketplace Toolkits Registered Successfully for Organisation ID : {organisation.id}!")
def extract_repo_name(repo_link):
    """
    Extract the repository name from a link such as https://github.com/username/repoName.

    Trailing slashes are ignored, so ".../repoName/" also yields "repoName" (the previous
    implementation returned an empty string for it).
    """
    return repo_link.rstrip('/').rsplit('/', 1)[-1]
def add_tool_to_json(repo_link):
    """
    Record ``repo_link`` in ./tools.json under ``tools[<repo name>]``, rewriting it with indent=2.
    """
    tools_json_path = 'tools.json'
    with open(tools_json_path, 'r') as handle:
        tools_data = json.load(handle)
    repo_name = extract_repo_name(repo_link)
    tools_data['tools'][repo_name] = repo_link
    with open(tools_json_path, 'w') as handle:
        json.dump(tools_data, handle, indent=2)
def handle_tools_import():
    """
    Put every tool folder under the known tool roots on ``sys.path`` so tool modules import.

    Roots are relative to the current working directory and missing roots are skipped.
    A folder that is already on ``sys.path`` is not appended again.
    """
    print("Handling tools import")
    tool_paths = ["superagi/tools", "superagi/tools/marketplace_tools", "superagi/tools/external_tools"]
    for tool_path in tool_paths:
        if not os.path.exists(tool_path):
            continue
        for folder_name in os.listdir(tool_path):
            folder_dir = os.path.join(tool_path, folder_name)
            if os.path.isdir(folder_dir) and folder_dir not in sys.path:
                sys.path.append(folder_dir)
def compare_tools(tool1, tool2):
    """Return True when two tool dicts differ in "name" or "description"."""
    return not all(tool1.get(field) == tool2.get(field) for field in ("name", "description"))
def compare_configs(config1, config2):
    """Return True when two tool-config dicts carry different "key" values."""
    return bool(config1.get("key") != config2.get("key"))
def compare_toolkit(toolkit1, toolkit2):
    """
    Return True when two toolkit dicts differ in their main fields, their tools, or their configs.

    Tools are paired after sorting by "name" and configs after sorting by "key". Lists of
    different length count as different without element-wise comparison. Each partial result
    is printed for debugging.
    """
    def _sorted_by(toolkit, collection, sort_key):
        return sorted(toolkit.get(collection, []), key=lambda item: item.get(sort_key, ""))

    main_fields = ["description", "show_toolkit", "name", "tool_code_link"]
    toolkit_diff = any(toolkit1.get(field) != toolkit2.get(field) for field in main_fields)

    left_tools = _sorted_by(toolkit1, "tools", "name")
    right_tools = _sorted_by(toolkit2, "tools", "name")
    tools_diff = (len(left_tools) != len(right_tools)
                  or any(compare_tools(a, b) for a, b in zip(left_tools, right_tools)))

    left_configs = _sorted_by(toolkit1, "configs", "key")
    right_configs = _sorted_by(toolkit2, "configs", "key")
    tool_configs_diff = (len(left_configs) != len(right_configs)
                         or any(compare_configs(a, b) for a, b in zip(left_configs, right_configs)))

    print("toolkit_diff : ", toolkit_diff)
    print("tools_diff : ", tools_diff)
    print("tool_configs_diff : ", tool_configs_diff)
    return toolkit_diff or tools_diff or tool_configs_diff
================================================
FILE: superagi/helper/twitter_helper.py
================================================
import os
import json
import base64
import requests
from requests_oauthlib import OAuth1
from requests_oauthlib import OAuth1Session
from superagi.helper.resource_helper import ResourceHelper
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.types.storage_types import StorageType
from superagi.config.config import get_config
from superagi.helper.s3_helper import S3Helper
class TwitterHelper:
    """Helpers for uploading media and posting tweets with OAuth 1.0a user credentials."""

    def get_media_ids(self, session, media_files, creds, agent_id, agent_execution_id):
        """
        Upload each media file to Twitter and collect the returned media ids.

        Args:
            session: Database session used to resolve the agent and agent execution.
            media_files (list[str]): File names resolved through the agent's resource paths.
            creds (Creds): OAuth 1.0a consumer and access-token credentials.
            agent_id (int): Agent whose resource directories are searched.
            agent_execution_id (int): Agent execution used for output-path placeholders.

        Returns:
            list[str]: Media ids in the same order as ``media_files``.
        """
        upload_endpoint = 'https://upload.twitter.com/1.1/media/upload.json'
        oauth = OAuth1(creds.api_key,
                       client_secret=creds.api_key_secret,
                       resource_owner_key=creds.oauth_token,
                       resource_owner_secret=creds.oauth_token_secret)
        media_ids = []
        for file in media_files:
            file_path = self.get_file_path(session, file, agent_id, agent_execution_id)
            b64_image = base64.b64encode(self._get_image_data(file_path))
            # OAuth1 signs the request and sets the Authorization header itself; the previous
            # {'Authorization': 'application/octet-stream'} header was bogus and got overwritten.
            response = requests.post(upload_endpoint,
                                     data={'media_data': b64_image},
                                     auth=oauth)
            media_ids.append(str(json.loads(response.text)['media_id']))
        return media_ids

    def get_file_path(self, session, file_name, agent_id, agent_execution_id):
        """Resolve ``file_name`` to a readable resource path for the given agent and execution."""
        return ResourceHelper.get_agent_read_resource_path(
            file_name,
            agent=Agent.get_agent_from_id(session, agent_id),
            agent_execution=AgentExecution.get_agent_execution_from_id(session, agent_execution_id))

    def _get_image_data(self, file_path):
        """Return the raw bytes of ``file_path`` from S3 or local disk, per STORAGE_TYPE.

        Only one definition is kept. A later duplicate compared the raw config string with the
        StorageType enum, so it always fell through to the local filesystem, even on S3.
        """
        if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
            return S3Helper().read_binary_from_s3(file_path)
        with open(file_path, "rb") as file:
            return file.read()

    def send_tweets(self, params, creds):
        """POST ``params`` as JSON to the v2 tweets endpoint and return the raw response."""
        tweet_endpoint = "https://api.twitter.com/2/tweets"
        oauth = OAuth1Session(creds.api_key,
                              client_secret=creds.api_key_secret,
                              resource_owner_key=creds.oauth_token,
                              resource_owner_secret=creds.oauth_token_secret)
        response = oauth.post(tweet_endpoint, json=params)
        return response
================================================
FILE: superagi/helper/twitter_tokens.py
================================================
import hmac
import time
import random
import base64
import hashlib
import urllib.parse
import ast
import http.client as http_client
from sqlalchemy.orm import Session
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens
from superagi.config.config import get_config
class Creds:
    """Plain container for Twitter OAuth 1.0a consumer and access-token credentials."""

    def __init__(self, api_key, api_key_secret, oauth_token, oauth_token_secret):
        self.api_key, self.api_key_secret = api_key, api_key_secret
        self.oauth_token, self.oauth_token_secret = oauth_token, oauth_token_secret
class TwitterTokens:
    """Requests Twitter OAuth 1.0a request tokens and loads stored Twitter credentials."""

    def __init__(self, session: Session):
        self.session = session

    def get_request_token(self, api_data):
        """
        Request a temporary OAuth 1.0a request token from Twitter.

        The callback URL is localhost for ENV == "DEV" and app.superagi.com otherwise.

        Args:
            api_data (dict): Must contain "api_key" and "api_secret" (consumer credentials).

        Returns:
            dict: The form-decoded response body.
        """
        api_key = api_data["api_key"]
        api_secret_key = api_data["api_secret"]
        http_method = 'POST'
        base_url = 'https://api.twitter.com/oauth/request_token'
        env = get_config("ENV", "DEV")
        if env == "DEV":
            oauth_callback = "http://localhost:3000/api/twitter/oauth-tokens"
        else:
            oauth_callback = "https://app.superagi.com/api/twitter/oauth-tokens"
        params = {
            'oauth_callback': oauth_callback,
            'oauth_consumer_key': api_key,
            'oauth_nonce': self.gen_nonce(),
            'oauth_signature_method': 'HMAC-SHA1',
            'oauth_timestamp': int(time.time()),
            'oauth_version': '1.0'
        }
        # Signature base string: method, encoded URL and encoded, sorted parameter string.
        params_qs = '&'.join([f'{k}={self.percent_encode(str(v))}' for k, v in sorted(params.items())])
        base_string = f'{http_method}&{self.percent_encode(base_url)}&{self.percent_encode(params_qs)}'
        # No token secret exists yet, so the signing key ends with a bare "&".
        signing_key = f'{self.percent_encode(api_secret_key)}&'
        signature = hmac.new(signing_key.encode(), base_string.encode(), hashlib.sha1)
        params['oauth_signature'] = base64.b64encode(signature.digest()).decode()
        auth_header = 'OAuth ' + ', '.join([f'{k}="{self.percent_encode(str(v))}"' for k, v in params.items()])
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded',
            'Authorization': auth_header
        }
        conn = http_client.HTTPSConnection("api.twitter.com")
        try:
            conn.request("POST", "/oauth/request_token", "", headers)
            response_data = conn.getresponse().read().decode('utf-8')
        finally:
            # Close the connection even if the request or read fails.
            conn.close()
        return dict(urllib.parse.parse_qsl(response_data))

    def percent_encode(self, val):
        """Percent-encode ``val`` with no safe characters, as OAuth 1.0a signing requires."""
        return urllib.parse.quote(val, safe='')

    def gen_nonce(self):
        """Return a 32-digit nonce drawn from a cryptographically secure random source."""
        import secrets
        return ''.join(str(secrets.randbelow(10)) for _ in range(32))

    def get_twitter_creds(self, toolkit_id):
        """
        Load the stored Twitter OAuth credentials for ``toolkit_id`` within its organisation.

        Returns:
            Creds: Built from the stored value, which is parsed with ``ast.literal_eval``.
        """
        toolkit = self.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
        organisation_id = toolkit.organisation_id
        twitter_creds = self.session.query(OauthTokens).filter(OauthTokens.toolkit_id == toolkit_id, OauthTokens.organisation_id == organisation_id).first()
        twitter_creds = ast.literal_eval(twitter_creds.value)
        return Creds(twitter_creds['api_key'], twitter_creds['api_key_secret'], twitter_creds['oauth_token'], twitter_creds['oauth_token_secret'])
================================================
FILE: superagi/helper/validate_csv.py
================================================
import csv
import pandas as pd
import chardet
from superagi.lib.logger import logger
def correct_csv_encoding(file_path):
    """
    Re-encode a CSV file to UTF-8 in place when chardet detects a different encoding.

    The content is transcoded as-is, so line endings, quoting and row shapes are preserved.
    The old pandas round trip prepended a "0,1,..." header row and padded ragged rows.
    ASCII counts as already UTF-8. A file whose encoding cannot be detected is left untouched.

    Args:
        file_path (str): Path of the CSV file to normalise.

    Raises:
        UnicodeDecodeError: If the bytes cannot be decoded with the detected encoding.
    """
    with open(file_path, 'rb') as f:
        raw_bytes = f.read()
    encoding = chardet.detect(raw_bytes)['encoding']
    if encoding is None:
        logger.warning("Could not detect the file encoding; leaving the file unchanged.")
        return
    if encoding.lower() in ('utf-8', 'ascii'):
        logger.info("File is already in utf-8 encoding.")
        return
    text = raw_bytes.decode(encoding)
    # newline='' keeps the original line endings byte-for-byte.
    with open(file_path, 'w', encoding='utf-8', newline='') as f:
        f.write(text)
    logger.info("File is converted to utf-8 encoding.")
================================================
FILE: superagi/helper/webhook_manager.py
================================================
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.webhooks import Webhooks
from superagi.models.webhook_events import WebhookEvents
import requests
import json
from superagi.lib.logger import logger
class WebHookManager:
    """Sends organisation webhooks when an agent execution changes status."""

    # Upper bound (seconds) for each webhook POST, so a slow endpoint cannot stall the caller.
    REQUEST_TIMEOUT_SECONDS = 10

    def __init__(self, session):
        self.session = session

    def agent_status_change_callback(self, agent_execution_id, curr_status, old_status):
        """
        Notify matching organisation webhooks about an agent execution status change.

        Every webhook whose ``filters["status"]`` contains ``curr_status`` receives a JSON POST
        with the agent id, organisation id and "<old> to <new>" event. The outcome of each call
        is stored as a WebhookEvents row.

        Args:
            agent_execution_id (int): Execution whose status changed; None is ignored.
            curr_status (str): New status; "CREATED" is ignored.
            old_status (str): Previous status.
        """
        if curr_status == "CREATED" or agent_execution_id is None:
            return
        agent_id = AgentExecution.get_agent_execution_from_id(self.session, agent_execution_id).agent_id
        agent = Agent.get_agent_from_id(self.session, agent_id)
        org = agent.get_agent_organisation(self.session)
        org_webhooks = self.session.query(Webhooks).filter(Webhooks.org_id == org.id).all()
        event = f"{old_status} to {curr_status}"
        for webhook_obj in org_webhooks:
            # Webhooks without filters (or without a status list) simply do not match.
            filters = webhook_obj.filters or {}
            if "status" not in filters or curr_status not in (filters["status"] or []):
                continue
            webhook_obj_body = {"agent_id": agent_id, "org_id": org.id, "event": event}
            error = None
            request = None
            try:
                request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body),
                                        headers=webhook_obj.headers, timeout=self.REQUEST_TIMEOUT_SECONDS)
            except Exception as e:
                logger.error(f"Exception occurred in webhooks {e}")
                error = str(e)
            # Any 2xx response (e.g. 202 Accepted, 204 No Content) counts as delivered.
            if request is not None and not 200 <= request.status_code < 300 and error is None:
                error = request.text
            status = 'Error' if error is not None else 'sent'
            webhook_event = WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=event,
                                          status=status, errors=error)
            self.session.add(webhook_event)
            self.session.commit()
================================================
FILE: superagi/helper/webpage_extractor.py
================================================
from io import BytesIO
from PyPDF2 import PdfFileReader
from PyPDF2 import PdfReader
import requests
import re
from requests.exceptions import RequestException
from bs4 import BeautifulSoup
from newspaper import Article, ArticleException, Config
from requests_html import HTMLSession
import time
import random
from lxml import html
from superagi.lib.logger import logger
# Browser User-Agent strings; WebpageExtractor picks one with random.choice for each fetch
# (sent as a request header in the bs4 path, set on the newspaper Config in the 3k/lxml paths).
USER_AGENTS = [
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/12.1.1 Safari/605.1.15",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/64.0.3282.140 Safari/537.36 Edge/17.0",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:72.0) Gecko/20100101 Firefox/72.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.97 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/13.0.3 Safari/605.1.15",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:80.0) Gecko/20100101 Firefox/80.0",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.87 Safari/537.36",
"Mozilla/5.0 (iPhone; CPU iPhone OS 13_3 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/13.0.4 Mobile/15E148 Safari/604.1",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:73.0) Gecko/20100101 Firefox/73.0",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.129 Safari/537.36"
]
class WebpageExtractor:
    """Extracts readable text from a web page (or PDF URL) using one of several strategies."""

    # Seconds allowed for plain HTTP fetches and for JS rendering.
    REQUEST_TIMEOUT = 10

    def __init__(self, num_extracts=3):
        """
        Initialize the WebpageExtractor class.

        Args:
            num_extracts (int): Stored on the instance; not used by the methods in this class.
        """
        self.num_extracts = num_extracts

    @staticmethod
    def _close_session(session):
        """Close an HTMLSession (and any browser it launched), logging instead of raising on failure."""
        if session is None:
            return
        try:
            session.close()
        except Exception as e:
            logger.error(f"Error while closing HTML session: {str(e)}")

    def extract_with_3k(self, url):
        """
        Extract the text from a webpage using the 3k (newspaper3k) method.

        URLs ending in ``.pdf`` are downloaded and their pages' text joined with spaces.
        Other URLs are rendered with requests_html and parsed by newspaper's Article.

        Args:
            url (str): The URL of the webpage to extract from.

        Returns:
            str: At most the first 1500 characters of the extracted text, an error
            message for newspaper/request errors, or "" for any other error.
        """
        session = None
        try:
            if url.lower().endswith(".pdf"):
                response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
                response.raise_for_status()
                with BytesIO(response.content) as pdf_data:
                    reader = PdfReader(pdf_data)
                    # PdfReader.getPage()/getNumPages() were removed in PyPDF2 3.x; iterate
                    # reader.pages instead. extract_text() may yield None for image-only pages.
                    content = " ".join((page.extract_text() or "") for page in reader.pages)
            else:
                config = Config()
                config.browser_user_agent = random.choice(USER_AGENTS)
                config.request_timeout = self.REQUEST_TIMEOUT
                session = HTMLSession()
                response = session.get(url, timeout=config.request_timeout)
                response.html.render(timeout=config.request_timeout)
                html_content = response.html.html
                article = Article(url, config=config)
                article.set_html(html_content)
                article.parse()
                content = article.text.replace('\t', ' ').replace('\n', ' ').strip()
            return content[:1500]
        except ArticleException as ae:
            logger.error(f"Error while extracting text from HTML (newspaper3k): {str(ae)}")
            return f"Error while extracting text from HTML (newspaper3k): {str(ae)}"
        except RequestException as req_err:
            logger.error(f"Error while making the request to the URL (newspaper3k): {str(req_err)}")
            return f"Error while making the request to the URL (newspaper3k): {str(req_err)}"
        except Exception as e:
            logger.error(f"Unknown error while extracting text from HTML (newspaper3k): {str(e)}")
            return ""
        finally:
            self._close_session(session)

    def extract_with_bs4(self, url):
        """
        Extract the text from a webpage using the BeautifulSoup4 method.

        Boilerplate tags are removed, then the text of paragraph and heading tags is
        collected from the <main>/<article>/<section>/<div> with the most text.

        Args:
            url (str): The URL of the webpage to extract from.

        Returns:
            str: The extracted text, an error message for non-200 responses, or "" on exceptions.
        """
        headers = {
            "User-Agent": random.choice(USER_AGENTS)
        }
        try:
            response = requests.get(url, headers=headers, timeout=self.REQUEST_TIMEOUT)
            if response.status_code == 200:
                soup = BeautifulSoup(response.text, 'html.parser')
                for tag in soup(['script', 'style', 'nav', 'footer', 'head', 'link', 'meta', 'noscript']):
                    tag.decompose()
                content_tags = ['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']
                main_content_areas = soup.find_all(['main', 'article', 'section', 'div'])
                if main_content_areas:
                    main_content = max(main_content_areas, key=lambda x: len(x.text))
                    content = ' '.join([tag.text.strip() for tag in main_content.find_all(content_tags)])
                else:
                    content = ' '.join([tag.text.strip() for tag in soup.find_all(content_tags)])
                content = re.sub(r'\t', ' ', content)
                content = re.sub(r'\s+', ' ', content)
                return content
            elif response.status_code == 404:
                return "Error: 404. Url is invalid or does not exist. Try with valid url..."
            else:
                logger.error(f"Error while extracting text from HTML (bs4): {response.status_code}")
                return f"Error while extracting text from HTML (bs4): {response.status_code}"
        except Exception as e:
            logger.error(f"Unknown error while extracting text from HTML (bs4): {str(e)}")
            return ""

    def extract_with_lxml(self, url):
        """
        Extract the text from a webpage using the lxml method.

        The page is rendered with requests_html and the text of all paragraph and
        heading elements is joined with spaces.

        Args:
            url (str): The URL of the webpage to extract from.

        Returns:
            str: The extracted text, or "" on any error.
        """
        session = None
        try:
            config = Config()
            config.browser_user_agent = random.choice(USER_AGENTS)
            config.request_timeout = self.REQUEST_TIMEOUT
            session = HTMLSession()
            response = session.get(url, timeout=config.request_timeout)
            response.html.render(timeout=config.request_timeout)
            html_content = response.html.html
            tree = html.fromstring(html_content)
            paragraphs = tree.cssselect('p, h1, h2, h3, h4, h5, h6')
            content = ' '.join([para.text_content() for para in paragraphs if para.text_content()])
            content = content.replace('\t', ' ').replace('\n', ' ').strip()
            return content
        except ArticleException as ae:
            logger.error(f"Error while extracting text from HTML (lxml): {str(ae)}")
            return ""
        except RequestException as req_err:
            logger.error(f"Error while making the request to the URL (lxml): {str(req_err)}")
            return ""
        except Exception as e:
            logger.error(f"Unknown error while extracting text from HTML (lxml): {str(e)}")
            return ""
        finally:
            self._close_session(session)
================================================
FILE: superagi/image_llms/__init__.py
================================================
================================================
FILE: superagi/image_llms/base_image_llm.py
================================================
from abc import ABC, abstractmethod
class BaseImageLlm(ABC):
    """Abstract interface implemented by image-generation backends."""

    @abstractmethod
    def get_image_model(self):
        """Return the image model this backend is configured with."""

    @abstractmethod
    def generate_image(self, prompt: str, size: int = 512, num: int = 2):
        """Generate images for ``prompt``; ``size`` and ``num`` are the requested size and image count."""
================================================
FILE: superagi/image_llms/openai_dalle.py
================================================
import openai
from superagi.config.config import get_config
from superagi.image_llms.base_image_llm import BaseImageLlm
class OpenAiDalle(BaseImageLlm):
    """Image generation through the OpenAI (DALL·E) image API."""

    def __init__(self, api_key, image_model=None, number_of_results=1):
        """
        Args:
            api_key (str): The OpenAI API key (also assigned to the global ``openai.api_key``).
            image_model (str): The image model.
            number_of_results (int): Default number of images requested per call.
        """
        self.number_of_results = number_of_results
        self.api_key = api_key
        self.image_model = image_model
        openai.api_key = api_key
        openai.api_base = get_config("OPENAI_API_BASE", "https://api.openai.com/v1")

    def get_image_model(self):
        """
        Returns:
            str: The image model.
        """
        return self.image_model

    def generate_image(self, prompt: str, size: int = 512, num: int = None):
        """
        Call the OpenAI image API.

        Args:
            prompt (str): The prompt.
            size (int): Edge length in pixels; the request size is "{size}x{size}".
            num (int, optional): The number of images. Defaults to ``number_of_results``
                given at construction, matching the BaseImageLlm interface's ``num`` parameter.

        Returns:
            dict: The response.
        """
        n = self.number_of_results if num is None else num
        response = openai.Image.create(
            prompt=prompt,
            n=n,
            size=f"{size}x{size}"
        )
        return response
================================================
FILE: superagi/jobs/__init__.py
================================================
================================================
FILE: superagi/jobs/agent_executor.py
================================================
from datetime import datetime, timedelta
from sqlalchemy.orm import sessionmaker
from superagi.llms.local_llm import LocalLLM
import superagi.worker
from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler
from superagi.agent.agent_tool_step_handler import AgentToolStepHandler
from superagi.agent.agent_workflow_step_wait_handler import AgentWaitStepHandler
from superagi.agent.types.wait_step_status import AgentWorkflowStepWaitStatus
from superagi.apm.event_handler import EventHandler
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.hugging_face import HuggingFace
from superagi.llms.llm_model_factory import get_model
from superagi.llms.replicate import Replicate
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.db import connect_db
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.agent_workflow_step_wait import AgentWorkflowStepWait
from superagi.types.vector_store_types import VectorStoreType
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.vector_factory import VectorFactory
from superagi.worker import execute_agent
from superagi.agent.types.agent_workflow_step_action_types import AgentWorkflowStepAction
from superagi.agent.types.agent_execution_status import AgentExecutionStatus
# from superagi.helper.tool_helper import get_tool_config_by_key
# Module-level engine and session factory; AgentExecutor methods open their sessions from `Session`.
engine = connect_db()
Session = sessionmaker(bind=engine)
class AgentExecutor:
    """Advances agent executions one workflow step at a time and re-queues them on the worker."""

    def execute_next_step(self, agent_execution_id):
        """
        Execute the current workflow step of an agent execution and schedule the next call.

        Skips executions that are missing, older than one day, belong to a missing or
        deleted agent, are not RUNNING/WAITING_FOR_PERMISSION, or have hit their
        iteration limit. On a step failure the execution is retried after 15 seconds.

        Args:
            agent_execution_id (int): Id of the AgentExecution to advance.
        """
        global engine
        # Drop pooled connections before opening a new session.
        # NOTE(review): presumably to avoid reusing connections across worker processes — confirm.
        engine.dispose()
        session = Session()
        try:
            agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
            if agent_execution is None:
                logger.error(f"Agent execution {agent_execution_id} not found, skipping execution")
                return
            # Avoid running old agent executions.
            if agent_execution.created_at < datetime.utcnow() - timedelta(days=1):
                logger.error("Older agent execution found, skipping execution")
                return
            agent = session.query(Agent).filter(Agent.id == agent_execution.agent_id).first()
            if agent is None:
                logger.error(f"Agent {agent_execution.agent_id} not found, skipping execution")
                return
            agent_config = Agent.fetch_configuration(session, agent.id)
            if agent.is_deleted or (
                    agent_execution.status != AgentExecutionStatus.RUNNING.value
                    and agent_execution.status != AgentExecutionStatus.WAITING_FOR_PERMISSION.value):
                logger.error(f"Agent execution stopped. {agent.id}: {agent_execution.status}")
                return

            organisation = Agent.find_org_by_agent_id(session, agent_id=agent.id)
            if self._check_for_max_iterations(session, organisation.id, agent_config, agent_execution_id):
                logger.error(f"Agent execution stopped. Max iteration exceeded. {agent.id}: {agent_execution.status}")
                return

            try:
                model_config = AgentConfiguration.get_model_api_key(session, agent_execution.agent_id,
                                                                    agent_config["model"])
                model_api_key = model_config['api_key']
                model_llm_source = model_config['provider']
            except Exception as e:
                logger.info(f"Unable to get model config...{e}")
                return

            # Long-term memory is only set up for OpenAI-backed agents; failures are non-fatal.
            try:
                memory = None
                if "OpenAI" in model_llm_source:
                    vector_store_type = VectorStoreType.get_vector_store_type(get_config("LTM_DB", "Redis"))
                    memory = VectorFactory.get_vector_storage(vector_store_type, "super-agent-index1",
                                                              AgentExecutor.get_embedding(model_llm_source,
                                                                                          model_api_key))
            except Exception as e:
                logger.info(f"Unable to setup the connection...{e}")
                memory = None

            agent_workflow_step = session.query(AgentWorkflowStep).filter(
                AgentWorkflowStep.id == agent_execution.current_agent_step_id).first()
            try:
                self.__execute_workflow_step(agent, agent_config, agent_execution_id, agent_workflow_step, memory,
                                             model_api_key, organisation, session)
            except Exception as e:
                logger.info("Exception in executing the step: {}".format(e))
                superagi.worker.execute_agent.apply_async((agent_execution_id, datetime.now()), countdown=15)
                return

            # Re-read the execution: the step handler may have changed its status.
            agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
            if agent_execution.status == "COMPLETED" or agent_execution.status == "WAITING_FOR_PERMISSION":
                logger.info("Agent Execution is completed or waiting for permission")
                return
            superagi.worker.execute_agent.apply_async((agent_execution_id, datetime.now()), countdown=2)
        finally:
            session.close()
            engine.dispose()

    def __execute_workflow_step(self, agent, agent_config, agent_execution_id, agent_workflow_step, memory,
                                model_api_key, organisation, session):
        """Dispatch the workflow step to the handler matching its action type (TOOL, ITERATION_WORKFLOW, WAIT_STEP)."""
        action_type = agent_workflow_step.action_type
        logger.info(f"Executing Workflow step : {action_type}")
        if action_type == AgentWorkflowStepAction.TOOL.value:
            llm = get_model(model=agent_config["model"], api_key=model_api_key, organisation_id=organisation.id)
            tool_step_handler = AgentToolStepHandler(session, llm=llm, agent_id=agent.id,
                                                     agent_execution_id=agent_execution_id, memory=memory)
            tool_step_handler.execute_step()
        elif action_type == AgentWorkflowStepAction.ITERATION_WORKFLOW.value:
            # Build the LLM client once; it was previously also built a second time just to print it.
            llm = get_model(model=agent_config["model"], api_key=model_api_key, organisation_id=organisation.id)
            iteration_step_handler = AgentIterationStepHandler(session, llm=llm, agent_id=agent.id,
                                                               agent_execution_id=agent_execution_id, memory=memory)
            iteration_step_handler.execute_step()
        elif action_type == AgentWorkflowStepAction.WAIT_STEP.value:
            AgentWaitStepHandler(session=session, agent_id=agent.id,
                                 agent_execution_id=agent_execution_id).execute_step()

    @classmethod
    def get_embedding(cls, model_source, model_api_key):
        """
        Return the embedding/model object matching the provider name, or None if unrecognised.

        Args:
            model_source (str): Provider name; matched by substring ("OpenAI", "Google", ...).
            model_api_key (str): API key passed to the constructed object.
        """
        if "OpenAI" in model_source:
            return OpenAiEmbedding(api_key=model_api_key)
        if "Google" in model_source:
            return GooglePalm(api_key=model_api_key)
        if "Hugging" in model_source:
            return HuggingFace(api_key=model_api_key)
        if "Replicate" in model_source:
            return Replicate(api_key=model_api_key)
        if "Custom" in model_source:
            return LocalLLM()
        return None

    def _check_for_max_iterations(self, session, organisation_id, agent_config, agent_execution_id):
        """
        Mark the execution ITERATION_LIMIT_EXCEEDED and return True once num_of_calls reaches max_iterations.

        Also records a 'run_iteration_limit_crossed' event. Returns False otherwise.
        """
        db_agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
        if agent_config["max_iterations"] <= db_agent_execution.num_of_calls:
            db_agent_execution.status = AgentExecutionStatus.ITERATION_LIMIT_EXCEEDED.value
            EventHandler(session=session).create_event('run_iteration_limit_crossed',
                                                       {'agent_execution_id': db_agent_execution.id,
                                                        'name': db_agent_execution.name,
                                                        'tokens_consumed': db_agent_execution.num_of_tokens,
                                                        "calls": db_agent_execution.num_of_calls},
                                                       db_agent_execution.agent_id, organisation_id)
            session.commit()
            logger.info("ITERATION_LIMIT_CROSSED")
            return True
        return False

    def execute_waiting_workflows(self):
        """Check if wait time of wait workflow step is over and can be resumed."""
        session = Session()
        try:
            waiting_agent_executions = session.query(AgentExecution).filter(
                AgentExecution.status == AgentExecutionStatus.WAIT_STEP.value,
            ).all()
            for agent_execution in waiting_agent_executions:
                workflow_step = session.query(AgentWorkflowStep).filter(
                    AgentWorkflowStep.id == agent_execution.current_agent_step_id).first()
                if workflow_step is None:
                    logger.error(f"Workflow step not found for agent execution {agent_execution.id}")
                    continue
                step_wait = AgentWorkflowStepWait.find_by_id(session, workflow_step.action_reference_id)
                if step_wait is None:
                    continue
                # The previous `delay if not None else 0` always chose `delay`, even when it was None.
                wait_time = step_wait.delay if step_wait.delay is not None else 0
                elapsed = (datetime.now() - step_wait.wait_begin_time).total_seconds()
                logger.info(f"Agent Execution ID: {agent_execution.id}")
                logger.info(f"Wait time: {wait_time}")
                logger.info(f"Wait begin time: {step_wait.wait_begin_time}")
                logger.info(f"Current time: {datetime.now()}")
                logger.info(f"Wait Difference : {elapsed}")
                if elapsed > wait_time and step_wait.status == AgentWorkflowStepWaitStatus.WAITING.value:
                    agent_execution.status = AgentExecutionStatus.RUNNING.value
                    step_wait.status = AgentWorkflowStepWaitStatus.COMPLETED.value
                    session.commit()
                    session.flush()
                    AgentWaitStepHandler(session=session, agent_id=agent_execution.agent_id,
                                         agent_execution_id=agent_execution.id).handle_next_step()
                    execute_agent.delay(agent_execution.id, datetime.now())
        finally:
            session.close()
================================================
FILE: superagi/jobs/scheduling_executor.py
================================================
import ast
from datetime import datetime
from fastapi import HTTPException
from sqlalchemy.orm import sessionmaker
from superagi.models.tool import Tool
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.worker import execute_agent
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.apm.event_handler import EventHandler
from superagi.models.knowledges import Knowledges
from superagi.models.db import connect_db
# Module-level engine and session factory; ScheduledAgentExecutor opens its sessions from `Session`.
engine = connect_db()
Session = sessionmaker(bind=engine)
class ScheduledAgentExecutor:
    """Starts a new run of an agent when its schedule fires."""

    def execute_scheduled_agent(self, agent_id: int, name: str):
        """
        Performs the execution of scheduled agents

        Creates an AgentExecution (CREATED, then RUNNING), copies the agent's
        configuration rows onto it, records a 'run_created' event (and a
        'knowledge_picked' event when a knowledge is configured), then queues the
        run on the worker. The session is always closed.

        Args:
            agent_id: Identifier of the agent
            name: Name given to the new agent execution

        Raises:
            HTTPException: 404 if no agent with ``agent_id`` exists.
        """
        session = Session()
        try:
            agent = session.query(Agent).get(agent_id)
            if not agent:
                raise HTTPException(status_code=404, detail="Agent not found")

            start_step = AgentWorkflow.fetch_trigger_step_id(session, agent.agent_workflow_id)
            # Only iteration workflows have an iteration trigger step; -1 is stored otherwise.
            iteration_step_id = IterationWorkflow.fetch_trigger_step_id(
                session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1

            db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(),
                                                agent_id=agent_id, name=name, num_of_calls=0,
                                                num_of_tokens=0,
                                                current_agent_step_id=start_step.id,
                                                iteration_workflow_step_id=iteration_step_id)
            session.add(db_agent_execution)
            session.commit()

            # update status from CREATED to RUNNING
            db_agent_execution.status = "RUNNING"
            session.commit()

            agent_execution_id = db_agent_execution.id
            agent_configurations = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all()
            for agent_config in agent_configurations:
                agent_execution_config = AgentExecutionConfiguration(agent_execution_id=agent_execution_id,
                                                                     key=agent_config.key, value=agent_config.value)
                session.add(agent_execution_config)

            organisation = agent.get_agent_organisation(session)
            organisation_id = organisation.id if organisation else 0
            EventHandler(session=session).create_event('run_created',
                                                       {'agent_execution_id': db_agent_execution.id,
                                                        'agent_execution_name': db_agent_execution.name},
                                                       agent_id,
                                                       organisation_id)

            agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(
                session=session, key='knowledge', agent_id=agent_id)
            if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
                knowledge = Knowledges.get_knowledge_from_id(session, int(agent_execution_knowledge.value))
                # The referenced knowledge may have been removed; skip the event instead of crashing.
                knowledge_name = knowledge.name if knowledge is not None else None
                if knowledge_name is not None:
                    EventHandler(session=session).create_event('knowledge_picked',
                                                               {'knowledge_name': knowledge_name,
                                                                'agent_execution_id': db_agent_execution.id},
                                                               agent_id,
                                                               organisation_id)
            session.commit()

            if db_agent_execution.status == "RUNNING":
                execute_agent.delay(db_agent_execution.id, datetime.now())
        finally:
            session.close()
================================================
FILE: superagi/lib/logger.py
================================================
import logging
import inspect
class CustomLogRecord(logging.LogRecord):
    """
    LogRecord whose filename/lineno point at the code that called the logger wrapper.

    The stack is walked outward from this constructor, skipping frames whose module
    is this one or ``logging``; the first remaining frame supplies the location.
    If none remains, the location becomes ("unknown", 0).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        ignored_modules = (__name__, 'logging')
        caller = inspect.currentframe().f_back
        while caller is not None and caller.f_globals['__name__'] in ignored_modules:
            caller = caller.f_back
        if caller is None:
            self.filename, self.lineno = "unknown", 0
        else:
            self.filename, self.lineno = caller.f_code.co_filename, caller.f_lineno
class SingletonMeta(type):
    """Metaclass that constructs each class at most once and returns that instance on every call."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls not in registry:
            # Only the first call runs __init__; later arguments are ignored.
            registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
class Logger(metaclass=SingletonMeta):
    """
    Process-wide wrapper around a named :mod:`logging` logger.

    Records are created as CustomLogRecord, so the reported file/line refer to the
    caller of this wrapper. Each level method logs ``message``. When extra
    positional arguments are given, it logs them on a second line, joined by spaces.
    """

    def __init__(self, logger_name='Super AGI', log_level=logging.DEBUG):
        # Guard against attaching a second console handler if __init__ runs again.
        if not hasattr(self, 'logger'):
            self.logger = logging.getLogger(logger_name)
            self.logger.setLevel(log_level)
            self.logger.makeRecord = self._make_custom_log_record
            console_handler = logging.StreamHandler()
            console_handler.setLevel(log_level)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S %Z')
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

    def _make_custom_log_record(self, name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None):
        return CustomLogRecord(name, level, fn, lno, msg, args, exc_info, func=func, extra=extra, sinfo=sinfo)

    @staticmethod
    def _join_args(args):
        """
        Render extra positional arguments as one plain string.

        Passing them straight through as ``logger.info(*args)`` would treat ``args[0]``
        as a %-format string, which fails for calls like ``info("msg", a, b)``.
        """
        return " ".join(str(arg) for arg in args)

    def debug(self, message, *args):
        self.logger.debug(message)
        if args:
            self.logger.debug(self._join_args(args))

    def info(self, message, *args):
        self.logger.info(message)
        if args:
            self.logger.info(self._join_args(args))

    def warning(self, message, *args):
        self.logger.warning(message)
        if args:
            self.logger.warning(self._join_args(args))

    def error(self, message, *args):
        self.logger.error(message)
        if args:
            self.logger.error(self._join_args(args))

    def critical(self, message, *args):
        self.logger.critical(message)
        if args:
            self.logger.critical(self._join_args(args))


logger = Logger('Super AGI')
================================================
FILE: superagi/llms/__init__.py
================================================
================================================
FILE: superagi/llms/base_llm.py
================================================
from abc import ABC, abstractmethod
class BaseLlm(ABC):
    """Abstract interface shared by all text LLM provider clients."""

    @abstractmethod
    def chat_completion(self, prompt):
        """Run a completion for ``prompt`` and return the provider's result."""

    @abstractmethod
    def get_source(self):
        """Return a human-readable name of the provider."""

    @abstractmethod
    def get_api_key(self):
        """Return the API key the client was built with."""

    @abstractmethod
    def get_model(self):
        """Return the model the client is configured to use."""

    @abstractmethod
    def get_models(self):
        """Return the model(s) available through this client."""

    @abstractmethod
    def verify_access_key(self):
        """Return whether the configured credentials are accepted."""
================================================
FILE: superagi/llms/google_palm.py
================================================
import google.generativeai as palm
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
class GooglePalm(BaseLlm):
    """Client for the Google PaLM text-generation API (google.generativeai)."""

    def __init__(self, api_key, model='models/chat-bison-001', temperature=0.6, candidate_count=1, top_k=40,
                 top_p=0.95):
        """
        Args:
            api_key (str): The Google PALM API key (configured globally on the palm module).
            model (str): The model.
            temperature (float): The temperature.
            candidate_count (int): The number of candidates.
            top_k (int): The top k.
            top_p (float): The top p.
        """
        self.model = model
        self.temperature = temperature
        self.candidate_count = candidate_count
        self.top_k = top_k
        self.top_p = top_p
        self.api_key = api_key
        palm.configure(api_key=api_key)

    def get_source(self):
        return "google palm"

    def get_api_key(self):
        """
        Returns:
            str: The API key.
        """
        return self.api_key

    def get_model(self):
        """
        Returns:
            str: The model.
        """
        return self.model

    def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT") or 800, examples=None,
                        context=""):
        """
        Call the Google PALM text API with the messages flattened into a single prompt.

        Args:
            messages (list): Dicts with "role" and "content" keys. A single message is
                sent as its raw content; several are rendered as "`role`: content" lines.
            max_tokens (int): Maximum number of output tokens.
            examples (list): Accepted for interface compatibility; currently unused.
            context (str): Accepted for interface compatibility; currently unused.

        Returns:
            dict: {"response", "content"} on success, or {"error", "message"} on failure.
        """
        # Role prefixes do not yield right results for a single-step prompt.
        if len(messages) == 1:
            prompt = messages[0]['content']
        else:
            prompt = "\n".join(["`" + message["role"] + "`: " + message["content"] for message in messages])
        try:
            # NOTE: Default chat based palm bison model has different issues. We will switch to it once it gets fixed.
            final_model = "models/text-bison-001" if self.model == "models/chat-bison-001" else self.model
            completion = palm.generate_text(
                model=final_model,
                temperature=self.temperature,
                candidate_count=self.candidate_count,
                top_k=self.top_k,
                top_p=self.top_p,
                prompt=prompt,
                max_output_tokens=int(max_tokens),
            )
            return {"response": completion, "content": completion.result}
        except Exception as exception:
            logger.error(f"Google palm Exception: {exception}")
            return {"error": "ERROR_GOOGLE_PALM", "message": "Google palm exception"}

    def verify_access_key(self):
        """
        Verify the access key is valid.

        Returns:
            bool: True if the access key is valid, False otherwise.
        """
        try:
            # palm.list_models() is lazy (a generator in google.generativeai); pull the
            # first item so the API request, and therefore the key check, actually happens.
            next(iter(palm.list_models()), None)
            return True
        except Exception as exception:
            logger.error(f"Google palm Exception: {exception}")
            return False

    def get_models(self):
        """
        Get the models.

        Returns:
            list: The supported model names (a fixed list).
        """
        return ["chat-bison-001"]
================================================
FILE: superagi/llms/grammar/json.gbnf
================================================
# GBNF (llama.cpp grammar format) for JSON text whose root value is an object.
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
# Any character except quote/backslash, or a backslash escape.
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
# Optional sign, integer part, optional fraction, optional exponent.
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
================================================
FILE: superagi/llms/hugging_face.py
================================================
import os
import requests
import json
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.llms.utils.huggingface_utils.tasks import Tasks, TaskParameters
from superagi.llms.utils.huggingface_utils.public_endpoints import ACCOUNT_VERIFICATION_URL
class HuggingFace(BaseLlm):
    """Client for a Hugging Face inference endpoint."""

    def __init__(
            self,
            api_key,
            model=None,
            end_point=None,
            task=Tasks.TEXT_GENERATION,
            **kwargs
    ):
        """
        Args:
            api_key (str): Hugging Face token, sent as a Bearer Authorization header.
            model (str): The model name.
            end_point (str): URL of the inference endpoint that requests are POSTed to.
            task: The inference task; selects the default parameters and response parsing.
            **kwargs: Forwarded to ``TaskParameters().get_params`` for the task.
        """
        self.api_key = api_key
        self.model = model
        self.end_point = end_point
        self.task = task
        self.task_params = TaskParameters().get_params(self.task, **kwargs)
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def get_source(self):
        return "hugging face"

    def get_api_key(self):
        """
        Returns:
            str: The API key.
        """
        return self.api_key

    def get_model(self):
        """
        Returns:
            str: The model.
        """
        return self.model

    def get_models(self):
        """
        Returns:
            str: The model.
        """
        return self.model

    def verify_access_key(self):
        """
        Verify the access key is valid.

        Returns:
            bool: True if the access key is valid, False otherwise.
        """
        response = requests.get(ACCOUNT_VERIFICATION_URL, headers=self.headers)
        # A more sophisticated check could be done here.
        # Ideally we should be checking the response from the endpoint along with the status code.
        # If the desired response is not received, we should return False and log the response.
        return response.status_code == 200

    def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")):
        """
        Call the HuggingFace inference API.

        Only the first message's content is sent when ``messages`` is a list, with a
        "\\nThe response in json schema:" suffix appended.

        Args:
            messages (list | str): The messages, or an already-built prompt string.
            max_tokens (int): The maximum number of new tokens (text generation only).

        Returns:
            dict: {"response", "content"} on success, or {"error", "message", "details"} on failure.
        """
        try:
            if isinstance(messages, list):
                messages = messages[0]["content"] + "\nThe response in json schema:"
            params = self.task_params
            if self.task == Tasks.TEXT_GENERATION:
                params["max_new_tokens"] = max_tokens
                params['return_full_text'] = False
            payload = {
                "inputs": messages,
                "parameters": params,
                "options": {
                    "use_cache": False,
                    "wait_for_model": True,
                }
            }
            response = requests.post(self.end_point, headers=self.headers, data=json.dumps(payload))
            completion = json.loads(response.content.decode("utf-8"))
            logger.info(f"{completion=}")
            if self.task == Tasks.TEXT_GENERATION:
                content = completion[0]["generated_text"]
            else:
                content = completion[0]["answer"]
            return {"response": completion, "content": content}
        except Exception as exception:
            logger.error(f"HF Exception: {exception}")
            return {"error": "ERROR_HUGGINGFACE", "message": "HuggingFace Inference exception", "details": exception}

    def verify_end_point(self):
        """
        Send a probe request to the configured endpoint.

        Returns:
            The endpoint's decoded JSON response.
        """
        data = json.dumps({"inputs": "validating end_point"})
        response = requests.post(self.end_point, headers=self.headers, data=data)
        return response.json()
================================================
FILE: superagi/llms/llm_model_factory.py
================================================
from superagi.llms.google_palm import GooglePalm
from superagi.llms.local_llm import LocalLLM
from superagi.llms.openai import OpenAi
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
from superagi.models.models_config import ModelsConfig
from superagi.models.models import Models
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs):
    """
    Build the LLM client configured for ``model`` in the given organisation.

    Looks up the model row and its provider configuration in the database, then
    instantiates the matching client class.

    Args:
        organisation_id (int): Organisation that owns the model configuration.
        api_key (str): API key passed to the provider client (not passed to LocalLLM).
        model (str): Model name as stored in the Models table.
        **kwargs: Extra constructor arguments forwarded to the client (not to LocalLLM).

    Returns:
        The provider client, or None if the provider name is not recognised.

    Raises:
        ValueError: If the model or its provider configuration cannot be found.
    """
    print("Fetching model details from database...")
    engine = connect_db()
    Session = sessionmaker(bind=engine)
    session = Session()
    try:
        model_instance = session.query(Models).filter(Models.org_id == organisation_id,
                                                      Models.model_name == model).first()
        if model_instance is None:
            raise ValueError(f"Model '{model}' not found for organisation {organisation_id}")
        response = session.query(ModelsConfig.provider).filter(
            ModelsConfig.org_id == organisation_id,
            ModelsConfig.id == model_instance.model_provider_id).first()
        if response is None:
            raise ValueError(f"Provider for model '{model}' not found for organisation {organisation_id}")
        provider_name = response.provider
    finally:
        # Release the connection even when a lookup fails.
        session.close()

    if provider_name == 'OpenAI':
        print("Provider is OpenAI")
        return OpenAi(model=model_instance.model_name, api_key=api_key, **kwargs)
    elif provider_name == 'Replicate':
        print("Provider is Replicate")
        return Replicate(model=model_instance.model_name, version=model_instance.version, api_key=api_key, **kwargs)
    elif provider_name == 'Google Palm':
        print("Provider is Google Palm")
        return GooglePalm(model=model_instance.model_name, api_key=api_key, **kwargs)
    elif provider_name == 'Hugging Face':
        print("Provider is Hugging Face")
        return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key,
                           **kwargs)
    elif provider_name == 'Local LLM':
        print("Provider is Local LLM")
        return LocalLLM(model=model_instance.model_name, context_length=model_instance.context_length)
    else:
        print('Unknown provider.')
        return None
def build_model_with_api_key(provider_name, api_key):
    """
    Instantiate a provider client from its (case-insensitive) provider name.

    Args:
        provider_name (str): One of "openai", "replicate", "google palm",
            "hugging face" or "local llm", in any letter case.
        api_key (str): API key passed to the client constructor.

    Returns:
        The client instance, or None (after printing a notice) for an unknown provider.
    """
    constructors = {
        'openai': OpenAi,
        'replicate': Replicate,
        'google palm': GooglePalm,
        'hugging face': HuggingFace,
        'local llm': LocalLLM,
    }
    constructor = constructors.get(provider_name.lower())
    if constructor is None:
        print('Unknown provider.')
        return None
    return constructor(api_key=api_key)
================================================
FILE: superagi/llms/local_llm.py
================================================
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.helper.llm_loader import LLMLoader
class LocalLLM(BaseLlm):
    """Client for a locally loaded model, run through LLMLoader with a grammar constraint."""

    def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1,
                 frequency_penalty=0,
                 presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY', context_length=4096):
        """
        Args:
            model (str): The model.
            temperature (float): The temperature.
            max_tokens (int): The maximum number of tokens.
            top_p (float): The top p.
            frequency_penalty (float): The frequency penalty.
            presence_penalty (float): The presence penalty.
            number_of_results (int): The number of results.
            api_key (str): Stored only; a local model needs no key.
            context_length (int): Context length passed to LLMLoader.
        """
        self.model = model
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.number_of_results = number_of_results
        self.context_length = context_length

        llm_loader = LLMLoader(self.context_length)
        self.llm_model = llm_loader.model
        self.llm_grammar = llm_loader.grammar

    def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")):
        """
        Call the chat completion.

        Args:
            messages (list): The messages.
            max_tokens (int): The maximum number of tokens.

        Returns:
            dict: {"response", "content"} on success, or {"error", "message"} on failure.
        """
        try:
            if self.llm_model is None or self.llm_grammar is None:
                logger.error("Model not found.")
                return {"error": "Model loading error",
                        "message": "Model not found. Please check your model path and try again."}
            response = self.llm_model.create_chat_completion(messages=messages, functions=None, function_call=None,
                                                             temperature=self.temperature, top_p=self.top_p,
                                                             max_tokens=int(max_tokens),
                                                             presence_penalty=self.presence_penalty,
                                                             frequency_penalty=self.frequency_penalty,
                                                             grammar=self.llm_grammar)
            content = response["choices"][0]["message"]["content"]
            logger.info(content)
            return {"response": response, "content": content}
        except Exception as exception:
            logger.error(f"Local LLM Exception: {exception}")
            return {"error": "ERROR", "message": "Error: " + str(exception)}

    def get_source(self):
        """
        Get the source.

        Returns:
            str: The source.
        """
        return "Local LLM"

    def get_api_key(self):
        """
        Returns:
            str: The API key.
        """
        return self.api_key

    def get_model(self):
        """
        Returns:
            str: The model.
        """
        return self.model

    def get_models(self):
        """
        Returns:
            str: The configured model (this client exposes a single model).
        """
        return self.model

    def verify_access_key(self, api_key=None):
        """
        Local models need no key, so this always succeeds.

        ``api_key`` is optional so the method matches BaseLlm.verify_access_key(self)
        while still accepting the argument from existing callers.

        Returns:
            bool: Always True.
        """
        return True
================================================
FILE: superagi/llms/openai.py
================================================
import openai
from openai import APIError, InvalidRequestError
from openai.error import RateLimitError, AuthenticationError, Timeout, TryAgain
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
# Retry policy for transient OpenAI errors, used by the @retry decorator on OpenAi.chat_completion.
MAX_RETRY_ATTEMPTS = 5  # Passed to stop_after_attempt()
MIN_WAIT = 30  # Seconds; lower bound for wait_random_exponential()
MAX_WAIT = 300  # Seconds; upper bound for wait_random_exponential()
def custom_retry_error_callback(retry_state):
    """
    Build the error payload returned once tenacity stops retrying.

    Args:
        retry_state: tenacity's retry state for the last failed attempt.

    Returns:
        dict: ``{"error": "ERROR_OPENAI", "message": ...}`` describing the last exception.
    """
    last_exception = retry_state.outcome.exception()
    logger.info("OpenAi Exception:", last_exception)
    error_message = "Open ai exception: " + str(last_exception)
    return {"error": "ERROR_OPENAI", "message": error_message}
class OpenAi(BaseLlm):
    """
    LLM backend for OpenAI chat models, using the module-level ``openai`` API
    (``openai.ChatCompletion``, ``openai.Model``, ``openai.error``).
    """

    def __init__(self, api_key, model="gpt-4", temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1,
                 frequency_penalty=0,
                 presence_penalty=0, number_of_results=1):
        """
        Args:
            api_key (str): The OpenAI API key.
            model (str): The model.
            temperature (float): The temperature.
            max_tokens (int): The maximum number of tokens.
            top_p (float): The top p.
            frequency_penalty (float): The frequency penalty.
            presence_penalty (float): The presence penalty.
            number_of_results (int): The number of results.

        Note:
            Sets the module-level ``openai.api_key`` and ``openai.api_base``. The API calls
            below do not pass a key, so the most recently constructed instance determines
            the key/base used by every OpenAi instance in the process.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.number_of_results = number_of_results
        self.api_key = api_key
        openai.api_key = api_key
        openai.api_base = get_config("OPENAI_API_BASE", "https://api.openai.com/v1")

    def get_source(self):
        return "openai"

    def get_api_key(self):
        """
        Returns:
            str: The API key.
        """
        return self.api_key

    def get_model(self):
        """
        Returns:
            str: The model.
        """
        return self.model

    @retry(
        retry=(
            retry_if_exception_type(RateLimitError) |
            retry_if_exception_type(Timeout) |
            retry_if_exception_type(TryAgain)
        ),
        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),  # Maximum number of retry attempts
        wait=wait_random_exponential(min=MIN_WAIT, max=MAX_WAIT),
        before_sleep=lambda retry_state: logger.info(f"{retry_state.outcome.exception()} (attempt {retry_state.attempt_number})"),
        retry_error_callback=custom_retry_error_callback
    )
    def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")):
        """
        Call the OpenAI chat completion API.

        Rate-limit, timeout and try-again errors are re-raised so the ``@retry`` decorator
        retries them; when retries run out, ``custom_retry_error_callback`` supplies the
        error dict. Other errors are returned directly as ``{"error", "message"}`` dicts.

        Args:
            messages (list): The messages.
            max_tokens (int): The maximum number of tokens.

        Returns:
            dict: ``{"response": ..., "content": str}`` on success, otherwise an error dict.
        """
        try:
            response = openai.ChatCompletion.create(
                n=self.number_of_results,
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=max_tokens,
                top_p=self.top_p,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty
            )
            content = response.choices[0].message["content"]
            return {"response": response, "content": content}
        except RateLimitError as api_error:
            logger.info("OpenAi RateLimitError:", api_error)
            # Bare ``raise`` keeps the original exception object (attributes and traceback)
            # rather than rebuilding a new one from its message.
            raise
        except Timeout as timeout_error:
            logger.info("OpenAi Timeout:", timeout_error)
            raise
        except TryAgain as try_again_error:
            logger.info("OpenAi TryAgain:", try_again_error)
            raise
        except AuthenticationError as auth_error:
            logger.info("OpenAi AuthenticationError:", auth_error)
            return {"error": "ERROR_AUTHENTICATION", "message": "Authentication error please check the api keys: "+str(auth_error)}
        except InvalidRequestError as invalid_request_error:
            logger.info("OpenAi InvalidRequestError:", invalid_request_error)
            return {"error": "ERROR_INVALID_REQUEST", "message": "Openai invalid request error: "+str(invalid_request_error)}
        except Exception as exception:
            logger.info("OpenAi Exception:", exception)
            return {"error": "ERROR_OPENAI", "message": "Open ai exception: "+str(exception)}

    def verify_access_key(self):
        """
        Verify the access key is valid by listing the available models.

        Returns:
            bool: True if the access key is valid, False otherwise.
        """
        try:
            openai.Model.list()
            return True
        except Exception as exception:
            logger.info("OpenAi Exception:", exception)
            return False

    def get_models(self):
        """
        Get the supported chat models that this key can access.

        Returns:
            list: Model ids, in the order the API lists them, filtered to the supported set;
            an empty list on any error.
        """
        try:
            models = openai.Model.list()
            models_supported = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
            return [model["id"] for model in models["data"] if model["id"] in models_supported]
        except Exception as exception:
            logger.info("OpenAi Exception:", exception)
            return []
================================================
FILE: superagi/llms/replicate.py
================================================
import os
import requests
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
class Replicate(BaseLlm):
    """LLM backend that runs a model version hosted on Replicate."""

    def __init__(self, api_key, model: str = None, version: str = None, max_length=1000, temperature=0.7,
                 candidate_count=1, top_k=40, top_p=0.95):
        """
        Args:
            api_key (str): The Replicate API key.
            model (str): The model.
            version (str): The version.
            max_length (int): Stored on the instance; chat_completion does not currently use it.
            temperature (float): The temperature.
            candidate_count (int): The number of candidates.
            top_k (int): The top k.
            top_p (float): The top p.
        """
        self.model = model
        self.version = version
        self.temperature = temperature
        self.candidate_count = candidate_count
        self.top_k = top_k
        self.top_p = top_p
        self.api_key = api_key
        self.max_length = max_length

    def get_source(self):
        return "replicate"

    def get_api_key(self):
        """
        Returns:
            str: The API key.
        """
        return self.api_key

    def get_model(self):
        """
        Returns:
            str: The model.
        """
        return self.model

    def get_models(self):
        """
        Returns:
            str: The model.
        """
        return self.model

    def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT") or 800):
        """
        Call the Replicate model API.

        The messages are flattened into a "role: content" transcript (a single message is
        sent as a "System:" prompt), followed by a final "Response:" line.

        Args:
            messages (list): The messages; each needs "content", plus "role" when there is
                more than one message.
            max_tokens (int): Accepted for interface compatibility; not forwarded (the
                request always uses ``max_length=40000``).

        Returns:
            dict: ``{"response": list of output chunks, "content": str}`` on success,
            otherwise ``{"error": ..., "message": str}``.
        """
        if len(messages) == 1:
            prompt = "System:" + messages[0]['content'] + "\nResponse:"
        else:
            prompt = "\n".join(message["role"] + ": " + message["content"] for message in messages)
            prompt += "\nResponse:"
        try:
            os.environ["REPLICATE_API_TOKEN"] = self.api_key
            # Imported only after the token is exported; presumably so the client reads it
            # from the environment -- TODO confirm against the replicate client.
            import replicate
            output_generator = replicate.run(
                self.model + ":" + self.version,
                input={"prompt": prompt, "max_length": 40000, "temperature": self.temperature,
                       "top_p": self.top_p}
            )
            # Keep the raw chunks as "response" and their concatenation as "content".
            final_output = ""
            temp_output = []
            for item in output_generator:
                final_output += item
                temp_output.append(item)
            if not final_output:
                no_output_message = "Replicate model didn't return any output."
                logger.error(no_output_message)
                return {"error": no_output_message, "message": no_output_message}
            logger.info("Replicate response:", final_output)
            return {"response": temp_output, "content": final_output}
        except Exception as exception:
            # f-string (not "+") so that a None model cannot raise a second TypeError
            # from inside this handler.
            logger.error(f"Replicate model {self.model} Exception:", exception)
            # "message" matches the error dicts returned by the other LLM backends.
            return {"error": exception, "message": "Replicate exception: " + str(exception)}

    def verify_access_key(self):
        """
        Verify the access key is valid.

        Returns:
            bool: True if the collections endpoint answers with status 200 for this key,
            False otherwise (including network errors and timeouts).
        """
        headers = {"Authorization": "Token " + self.api_key}
        try:
            # Without a timeout, requests may block indefinitely on an unresponsive server.
            response = requests.get("https://api.replicate.com/v1/collections", headers=headers, timeout=30)
        except requests.RequestException as exception:
            logger.error("Replicate access key verification failed:", exception)
            return False
        # If the request is successful, status code will be 200
        return response.status_code == 200
================================================
FILE: superagi/llms/utils/__init__.py
================================================
================================================
FILE: superagi/llms/utils/huggingface_utils/__init__.py
================================================
================================================
FILE: superagi/llms/utils/huggingface_utils/public_endpoints.py
================================================
# Hugging Face endpoint presumably queried to verify an access token
# (NOTE(review): confirm the caller's expectations against the Hugging Face Hub API docs).
ACCOUNT_VERIFICATION_URL = "https://huggingface.co/api/whoami-v2"
================================================
FILE: superagi/llms/utils/huggingface_utils/tasks.py
================================================
from enum import Enum
from dataclasses import dataclass
from typing import List, Dict, Union, Optional
# Enum of the Hugging Face tasks that have parameter sets defined below.
class Tasks(Enum):
    """
    Supported tasks; each value is the task's string identifier.

    Every member needs a matching entry in ``TaskParameters._generate_params``
    (checked by ``TaskParameters._validate_params``).
    """
    TEXT_GENERATION = "text-generation"
class TaskParameters:
    """Default parameter sets for each :class:`Tasks` member, with per-call overrides."""

    def __init__(self) -> None:
        self.params = self._generate_params()
        self._validate_params()

    def get_params(self, task, **kwargs) -> Dict[str, Union[int, float, bool, str]]:
        """
        Return the parameters for ``task`` with matching ``kwargs`` overrides applied.

        Keyword arguments that are not parameters of ``task`` are ignored. The stored
        defaults are never modified: every call returns a fresh dict.

        Args:
            task (Tasks): The task whose parameters are requested.
            **kwargs: Parameter overrides.

        Returns:
            dict: Parameter name -> value.

        Raises:
            KeyError: If ``task`` has no parameters defined.
        """
        # Copy so overrides from one call do not leak into the defaults seen by later calls.
        params = dict(self.params[task])
        params.update({name: value for name, value in kwargs.items() if name in params})
        return params

    def _generate_params(self):
        """Build the default parameter dict for every task."""
        return {
            Tasks.TEXT_GENERATION: TextGenerationParameters().__dict__,
        }

    def _validate_params(self):
        """Assert that every :class:`Tasks` member has a parameter set (skipped under ``python -O``)."""
        assert len(self.params) == len(Tasks), "Not all tasks have parameters defined"
        for task in Tasks:
            assert task in self.params, f"Task {task} does not have parameters defined"
@dataclass
class TextGenerationParameters():
    """
    Default parameters for the text-generation task (``Tasks.TEXT_GENERATION``).

    top_k: (Default: None). Integer.
        Integer to define the top tokens considered within the sample operation to create new text.
    top_p: (Default: None). Float.
        Float to define the tokens that are within the sample operation of text generation.
        Tokens are added to the sample from most probable to least probable until the sum of the probabilities is greater than top_p.
    temperature: (Default: 1.0). Float (0.0-100.0).
        The temperature of the sampling operation.
        1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
    repetition_penalty: (Default: None). Float (0.0-100.0).
        The more a token is used within generation, the more it is penalized to not be picked in successive generation passes.
    max_new_tokens: (Default: None). Int (0-250).
        The number of new tokens to be generated. This does not include the input length; it is an estimate of the size of the generated text you want. Each new token slows down the request, so look for a balance between response time and length of generated text.
    max_time: (Default: None). Float (0-120.0).
        The maximum amount of time in seconds that the query should take.
        Network can cause some overhead, so it will be a soft limit.
        Use it in combination with max_new_tokens for best results.
    return_full_text: (Default: True). Bool.
        If set to False, the returned results will not contain the original query, making it easier for prompting.
    num_return_sequences: (Default: 1). Integer.
        The number of propositions you want to be returned.
    do_sample: (Default: True). Bool.
        Whether or not to use sampling; use greedy decoding otherwise.
    """
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: float = 1.0
    repetition_penalty: Optional[float] = None
    max_new_tokens: Optional[int] = None
    max_time: Optional[float] = None
    return_full_text: bool = True
    num_return_sequences: int = 1
    do_sample: bool = True
================================================
FILE: superagi/models/__init__.py
================================================
import glob
from os.path import basename, dirname, isfile, join
# Export every sibling module except this package initialiser, so that
# ``from superagi.models import *`` imports them all.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(module_path)[:-len(".py")]
    for module_path in modules
    if isfile(module_path) and not module_path.endswith("__init__.py")
]
================================================
FILE: superagi/models/agent.py
================================================
from __future__ import annotations
import ast
import json
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy import or_
from superagi.lib.logger import logger
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_template import AgentTemplate
from superagi.models.agent_template_config import AgentTemplateConfig
# from superagi.models import AgentConfiguration
from superagi.models.base_model import DBBaseModel
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.workflows.agent_workflow import AgentWorkflow
class Agent(DBBaseModel):
    """
    Represents an agent entity.

    Attributes:
        id (int): The unique identifier of the agent.
        name (str): The name of the agent.
        project_id (int): The identifier of the associated project.
        description (str): The description of the agent.
        agent_workflow_id (int): The identifier of the associated agent workflow.
        is_deleted (bool): Soft-deletion flag for the agent.
    """

    __tablename__ = 'agents'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)
    project_id = Column(Integer)
    description = Column(String)
    agent_workflow_id = Column(Integer)
    is_deleted = Column(Boolean, default=False)

    def __repr__(self):
        """
        Returns a string representation of the Agent object.

        Returns:
            str: String representation of the Agent.
        """
        return f"Agent(id={self.id}, name='{self.name}', project_id={self.project_id}, " \
               f"description='{self.description}', agent_workflow_id={self.agent_workflow_id}," \
               f"is_deleted='{self.is_deleted}')"

    @classmethod
    def fetch_configuration(cls, session, agent_id: int):
        """
        Fetches the configuration of an agent.

        Args:
            session: The database session object.
            agent_id (int): The ID of the agent.

        Returns:
            dict: Parsed agent configuration. Stored keys that ``eval_agent_config`` does not
            handle (e.g. "user_timezone", written by ``create_agent_with_config``) are
            present with the value None.
        """
        agent = session.query(Agent).filter_by(id=agent_id).first()
        agent_configurations = session.query(AgentConfiguration).filter_by(
            agent_id=agent_id).all()
        parsed_config = {
            "agent_id": agent.id,
            "name": agent.name,
            "project_id": agent.project_id,
            "description": agent.description,
            "goal": [],
            "instruction": [],
            "constraints": [],
            "tools": [],
            "exit": None,
            "iteration_interval": None,
            "model": None,
            "permission_type": None,
            "LTM_DB": None,
            "memory_window": None,
            "max_iterations": None,
            "is_deleted": agent.is_deleted,
            "knowledge": None
        }
        if not agent_configurations:
            return parsed_config
        for item in agent_configurations:
            parsed_config[item.key] = cls.eval_agent_config(item.key, item.value)
        return parsed_config

    @classmethod
    def eval_agent_config(cls, key, value):
        """
        Evaluates the value of an agent configuration setting based on its key.

        Args:
            key (str): The key of the configuration setting.
            value (str): The stored (stringified) value of the configuration setting.

        Returns:
            object: The evaluated value, or None for keys not listed here.
        """
        if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB",
                   "resource_summary", "knowledge"]:
            return value
        elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]:
            return int(value)
        elif key in ["goal", "constraints", "instruction", "is_deleted"]:
            # NOTE(security): ``eval`` executes arbitrary Python. These values come from the
            # database and can originate from remote marketplace templates (see
            # create_agent_with_marketplace_template_id). ``ast.literal_eval`` -- already used
            # for "tools" below -- parses the same list/bool literals safely; consider switching
            # once all stored values are confirmed to be plain literals.
            return eval(value)
        elif key == "tools":
            return list(ast.literal_eval(value))

    @classmethod
    def create_agent_with_config(cls, db, agent_with_config):
        """
        Creates a new agent with the provided configuration.

        Args:
            db: The database object.
            agent_with_config: The object containing the agent and configuration details.

        Returns:
            Agent: The created agent.
        """
        db_agent = Agent(name=agent_with_config.name, description=agent_with_config.description,
                         project_id=agent_with_config.project_id)
        db.session.add(db_agent)
        db.session.flush()  # Flush pending changes to generate the agent's ID
        db.session.commit()

        # The workflow is resolved by name from the request.
        agent_workflow = AgentWorkflow.find_by_name(session=db.session, name=agent_with_config.agent_workflow)
        logger.info("Agent workflow:", str(agent_workflow))
        db_agent.agent_workflow_id = agent_workflow.id
        db.session.commit()

        # Create Agent Configuration; every value is stored as its str() form.
        agent_config_values = {
            "goal": agent_with_config.goal,
            "instruction": agent_with_config.instruction,
            "constraints": agent_with_config.constraints,
            "tools": agent_with_config.tools,
            "exit": agent_with_config.exit,
            "iteration_interval": agent_with_config.iteration_interval,
            "model": agent_with_config.model,
            "permission_type": agent_with_config.permission_type,
            "LTM_DB": agent_with_config.LTM_DB,
            "max_iterations": agent_with_config.max_iterations,
            "user_timezone": agent_with_config.user_timezone,
            "knowledge": agent_with_config.knowledge,
        }
        agent_configurations = [
            AgentConfiguration(agent_id=db_agent.id, key=key, value=str(value))
            for key, value in agent_config_values.items()
        ]
        db.session.add_all(agent_configurations)
        db.session.commit()
        db.session.flush()
        return db_agent

    @classmethod
    def create_agent_with_template_id(cls, db, project_id, agent_template):
        """
        Creates a new agent using the provided agent template ID.

        Args:
            db: The database object.
            project_id (int): The ID of the project.
            agent_template: The agent template object.

        Returns:
            Agent: The created agent.
        """
        db_agent = Agent(name=agent_template.name, description=agent_template.description,
                         project_id=project_id,
                         agent_workflow_id=agent_template.agent_workflow_id)
        db.session.add(db_agent)
        db.session.flush()  # Flush pending changes to generate the agent's ID
        db.session.commit()

        # Copy every template config row onto the new agent unchanged.
        configs = db.session.query(AgentTemplateConfig).filter(
            AgentTemplateConfig.agent_template_id == agent_template.id).all()
        agent_configurations = [
            AgentConfiguration(
                agent_id=db_agent.id, key=config.key, value=config.value
            )
            for config in configs
        ]
        db.session.add_all(agent_configurations)
        db.session.commit()
        db.session.flush()
        return db_agent

    @classmethod
    def create_agent_with_marketplace_template_id(cls, db, project_id, agent_template_id):
        """
        Creates a new agent using the agent template ID from the marketplace.

        Args:
            db: The database object.
            project_id (int): The ID of the project.
            agent_template_id (int): The ID of the agent template from the marketplace.

        Returns:
            Agent: The created agent.
        """
        agent_template = AgentTemplate.fetch_marketplace_detail(agent_template_id)
        # we need to create agent workflow if not present. Add it once we get org id in agent workflow
        db_agent = Agent(name=agent_template["name"], description=agent_template["description"],
                         project_id=project_id,
                         agent_workflow_id=agent_template["agent_workflow_id"])
        db.session.add(db_agent)
        db.session.flush()  # Flush pending changes to generate the agent's ID
        db.session.commit()

        agent_configurations = [
            AgentConfiguration(agent_id=db_agent.id, key=key, value=value["value"])
            for key, value in agent_template["configs"].items()
        ]
        db.session.add_all(agent_configurations)
        db.session.commit()
        db.session.flush()
        return db_agent

    def get_agent_organisation(self, session):
        """
        Get the organization of the agent.

        Args:
            session: The database session.

        Returns:
            Organization: The organization of the agent.
        """
        project = session.query(Project).filter(Project.id == self.project_id).first()
        return session.query(Organisation).filter(Organisation.id == project.organisation_id).first()

    @classmethod
    def get_agent_from_id(cls, session, agent_id):
        """
        Get Agent from agent_id.

        Args:
            session: The database session.
            agent_id (int): Unique identifier of an Agent.

        Returns:
            Agent: The Agent, or None if it does not exist.
        """
        return session.query(Agent).filter(Agent.id == agent_id).first()

    @classmethod
    def find_org_by_agent_id(cls, session: object, agent_id: int):
        """
        Finds the organization for the given agent.

        Args:
            session: The database session.
            agent_id: The agent id.

        Returns:
            Organisation: The found organization.
        """
        assert session, "Session cannot be None"
        agent = session.query(Agent).filter_by(id=agent_id).first()
        project = session.query(Project).filter(Project.id == agent.project_id).first()
        return session.query(Organisation).filter(Organisation.id == project.organisation_id).first()

    @classmethod
    def get_active_agent_by_id(cls, session, agent_id: int):
        """
        Fetch an agent that has not been soft-deleted.

        Agents whose ``is_deleted`` is NULL count as active.

        Args:
            session: The database session.
            agent_id (int): The agent id.

        Returns:
            Agent: The active agent, or None.
        """
        # ``.is_(None)`` builds SQL ``IS NULL``; a Python ``is None`` test would instead
        # evaluate to the constant False and silently drop NULL rows.
        db_agent = session.query(Agent).filter(Agent.id == agent_id,
                                               or_(Agent.is_deleted == False,
                                                   Agent.is_deleted.is_(None))).first()
        return db_agent
================================================
FILE: superagi/models/agent_config.py
================================================
from fastapi import HTTPException
from sqlalchemy import Column, Integer, Text, String
from typing import Union
from superagi.config.config import get_config
from superagi.helper.encyption_helper import decrypt_data
from superagi.models.base_model import DBBaseModel
from superagi.models.configuration import Configuration
from superagi.models.models_config import ModelsConfig
from superagi.types.model_source_types import ModelSourceType
from superagi.models.tool import Tool
from superagi.controllers.types.agent_execution_config import AgentRunIn
class AgentConfiguration(DBBaseModel):
    """
    Per-agent key/value settings (goals, instructions, constraints, tools, ...).

    Attributes:
        id (int): The unique identifier of the agent configuration.
        agent_id (int): The identifier of the associated agent.
        key (str): The key of the configuration setting.
        value (str): The value of the configuration setting, stored as text.
    """

    __tablename__ = 'agent_configurations'

    id = Column(Integer, primary_key=True, autoincrement=True)
    agent_id = Column(Integer)
    key = Column(String)
    value = Column(Text)

    def __repr__(self):
        """
        Returns:
            str: A readable representation of this configuration row.
        """
        return f"AgentConfiguration(id={self.id}, key={self.key}, value={self.value})"

    @classmethod
    def _upsert_config(cls, session, agent_id, key, value):
        """Set ``key`` to ``str(value)`` for the agent, adding a new row to the session if none exists."""
        existing_config = cls.get_agent_config_by_key_and_agent_id(session, key, agent_id)
        if existing_config:
            existing_config.value = str(value)
            return
        session.add(AgentConfiguration(agent_id=agent_id, key=key, value=str(value)))

    @classmethod
    def update_agent_configurations_table(cls, session, agent_id: Union[int, None], updated_details: AgentRunIn):
        """
        Write ``updated_details`` into the agent's configuration rows and commit.

        'toolkits' and 'knowledge' are created when missing; every other key in
        ``updated_details`` only updates rows that already exist.

        Args:
            session: The database session.
            agent_id (int | None): The agent identifier.
            updated_details (AgentRunIn): The new values.

        Returns:
            str: A fixed success message.
        """
        details = updated_details.dict()

        for upsert_key in ('toolkits', 'knowledge'):
            cls._upsert_config(session, agent_id, upsert_key, details[upsert_key])

        existing_rows = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all()
        for row in existing_rows:
            if row.key in details:
                row.value = str(details[row.key])

        session.commit()
        return "Details updated successfully"

    @classmethod
    def get_model_api_key(cls, session, agent_id: int, model: str):
        """
        Look up the model configuration for the agent.

        Args:
            session (Session): The database session.
            agent_id (int): The agent identifier.
            model (str): The model name.

        Returns:
            Whatever ``ModelsConfig.fetch_value_by_agent_id`` returns for this agent/model
            (NOTE(review): its exact shape is defined there, not here).
        """
        return ModelsConfig.fetch_value_by_agent_id(session, agent_id, model)

    @classmethod
    def get_agent_config_by_key_and_agent_id(cls, session, key: str, agent_id: int):
        """
        Returns:
            AgentConfiguration: The first row matching ``agent_id`` and ``key``, or None.
        """
        return session.query(AgentConfiguration).filter(
            AgentConfiguration.agent_id == agent_id,
            AgentConfiguration.key == key
        ).first()
================================================
FILE: superagi/models/agent_execution.py
================================================
import json
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime
from superagi.models.base_model import DBBaseModel
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.iteration_workflow import IterationWorkflow
class AgentExecution(DBBaseModel):
    """
    Represents single agent run

    Attributes:
        id (int): The unique identifier of the agent execution.
        status (str): The status of the agent execution. Possible values: 'CREATED', 'RUNNING', 'PAUSED',
            'COMPLETED', 'TERMINATED'.
        name (str): The name of the agent execution.
        agent_id (int): The identifier of the associated agent.
        last_execution_time (datetime): The timestamp of the last execution time.
        num_of_calls (int): The number of calls made during the execution.
        num_of_tokens (int): The number of tokens used during the execution.
        current_agent_step_id (int): The identifier of the current step in the execution.
    """

    __tablename__ = 'agent_executions'

    id = Column(Integer, primary_key=True)
    status = Column(String)  # like ('CREATED', 'RUNNING', 'PAUSED', 'COMPLETED', 'TERMINATED')
    name = Column(String)
    agent_id = Column(Integer)
    last_execution_time = Column(DateTime)
    num_of_calls = Column(Integer, default=0)
    num_of_tokens = Column(Integer, default=0)
    current_agent_step_id = Column(Integer)
    permission_id = Column(Integer)
    iteration_workflow_step_id = Column(Integer)
    current_feed_group_id = Column(String)
    last_shown_error_id = Column(Integer)

    def __repr__(self):
        """
        Returns a string representation of the AgentExecution object.

        Returns:
            str: String representation of the AgentExecution.
        """
        return (
            f"AgentExecution(id={self.id}, name={self.name}, status='{self.status}', "
            f"last_execution_time='{self.last_execution_time}', current_agent_step_id={self.current_agent_step_id}, "
            f"agent_id={self.agent_id}, num_of_calls={self.num_of_calls}, num_of_tokens={self.num_of_tokens},"
            f"permission_id={self.permission_id}, iteration_workflow_step_id={self.iteration_workflow_step_id})"
        )

    def to_dict(self):
        """
        Converts the AgentExecution object to a dictionary.

        Returns:
            dict: Dictionary representation of the AgentExecution. ``last_execution_time``
            is an ISO-8601 string, or None when unset.
        """
        last_execution_time = self.last_execution_time
        return {
            'id': self.id,
            'status': self.status,
            'name': self.name,
            'agent_id': self.agent_id,
            'last_execution_time': last_execution_time.isoformat() if last_execution_time is not None else None,
            'num_of_calls': self.num_of_calls,
            'num_of_tokens': self.num_of_tokens,
            'current_agent_step_id': self.current_agent_step_id,
            'permission_id': self.permission_id,
            'iteration_workflow_step_id': self.iteration_workflow_step_id
        }

    def to_json(self):
        """
        Converts the AgentExecution object to a JSON string.

        Returns:
            str: JSON string representation of the AgentExecution.
        """
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """
        Creates an AgentExecution object from a JSON string (the inverse of ``to_json``).

        Args:
            json_data (str): JSON string representing the AgentExecution object.

        Returns:
            AgentExecution: The created AgentExecution object.
        """
        data = json.loads(json_data)
        raw_time = data['last_execution_time']
        # to_dict emits None when the time was never set.
        last_execution_time = datetime.fromisoformat(raw_time) if raw_time is not None else None
        return cls(
            id=data['id'],
            status=data['status'],
            name=data['name'],
            agent_id=data['agent_id'],
            last_execution_time=last_execution_time,
            num_of_calls=data['num_of_calls'],
            num_of_tokens=data['num_of_tokens'],
            current_agent_step_id=data['current_agent_step_id'],
            permission_id=data['permission_id'],
            iteration_workflow_step_id=data['iteration_workflow_step_id']
        )

    @classmethod
    def get_agent_execution_from_id(cls, session, agent_execution_id):
        """
        Get an AgentExecution by its id.

        Args:
            session: The database session.
            agent_execution_id (int): Unique identifier of an Agent Execution.

        Returns:
            AgentExecution: The AgentExecution, or None if it does not exist.
        """
        return session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()

    @classmethod
    def find_by_id(cls, session, execution_id: int):
        """
        Finds an AgentExecution by its id.

        Args:
            session: The database session.
            execution_id (int): The id of the AgentExecution.

        Returns:
            AgentExecution: The AgentExecution object, or None.
        """
        return session.query(AgentExecution).filter(AgentExecution.id == execution_id).first()

    @classmethod
    def update_tokens(cls, session, agent_execution_id: int, total_tokens: int, new_llm_calls: int = 1):
        """
        Add ``total_tokens`` and ``new_llm_calls`` to the execution's counters and commit.

        Args:
            session: The database session.
            agent_execution_id (int): The id of the agent execution.
            total_tokens (int): Tokens to add.
            new_llm_calls (int): LLM calls to add.
        """
        agent_execution = session.query(AgentExecution).filter(
            AgentExecution.id == agent_execution_id).first()
        agent_execution.num_of_calls += new_llm_calls
        agent_execution.num_of_tokens += total_tokens
        session.commit()

    @classmethod
    def assign_next_step_id(cls, session, agent_execution_id: int, next_step_id: int):
        """
        Assigns next agent workflow step id to agent execution.

        When the next step is an ITERATION_WORKFLOW, the iteration workflow's trigger step
        is also recorded on the execution.

        Args:
            session: The database session.
            agent_execution_id (int): The id of the agent execution.
            next_step_id (int): The id of the next agent workflow step.
        """
        agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
        agent_execution.current_agent_step_id = next_step_id
        next_step = AgentWorkflowStep.find_by_id(session, next_step_id)
        if next_step.action_type == "ITERATION_WORKFLOW":
            trigger_step = IterationWorkflow.fetch_trigger_step_id(session, next_step.action_reference_id)
            agent_execution.iteration_workflow_step_id = trigger_step.id
        session.commit()

    @classmethod
    def get_execution_by_agent_id_and_status(cls, session, agent_id: int, status_filter: str):
        """Return the first execution of ``agent_id`` with status ``status_filter``, or None."""
        db_agent_execution = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id,
                                                                  AgentExecution.status == status_filter).first()
        return db_agent_execution

    @classmethod
    def get_all_executions_by_status_and_agent_id(cls, session, agent_id, execution_state_change_input, current_status: str):
        """
        Return the agent's executions in ``current_status``, restricted to
        ``execution_state_change_input.run_ids`` when that is not None.
        """
        if execution_state_change_input.run_ids is not None:
            db_execution_arr = session.query(AgentExecution).filter(
                AgentExecution.agent_id == agent_id, AgentExecution.status == current_status,
                AgentExecution.id.in_(execution_state_change_input.run_ids)).all()
        else:
            db_execution_arr = session.query(AgentExecution).filter(
                AgentExecution.agent_id == agent_id, AgentExecution.status == current_status).all()
        return db_execution_arr

    @classmethod
    def get_all_executions_by_filter_config(cls, session, agent_id: int, filter_config):
        """
        Return the agent's executions, optionally filtered by ``filter_config.run_ids`` and by
        ``filter_config.run_status_filter`` (ignored unless it is a known status).
        """
        db_execution_query = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id)
        if filter_config.run_ids is not None:
            db_execution_query = db_execution_query.filter(AgentExecution.id.in_(filter_config.run_ids))

        if filter_config.run_status_filter is not None and filter_config.run_status_filter in ["CREATED", "RUNNING",
                                                                                               "PAUSED", "COMPLETED",
                                                                                               "TERMINATED"]:
            db_execution_query = db_execution_query.filter(AgentExecution.status == filter_config.run_status_filter)

        db_execution_arr = db_execution_query.all()
        return db_execution_arr

    @classmethod
    def validate_run_ids(cls, session, run_ids: list, organisation_id: int):
        """
        Check that the given runs all belong to ``organisation_id``.

        Args:
            session: The database session.
            run_ids (list): Execution ids to validate.
            organisation_id (int): The expected organisation.

        Raises:
            Exception: If the runs resolve to no organisation (e.g. none of them exist), to
                several organisations, or to a different organisation.
        """
        from superagi.models.agent import Agent
        from superagi.models.project import Project

        run_ids = list(set(run_ids))
        agent_ids = session.query(AgentExecution.agent_id).filter(AgentExecution.id.in_(run_ids)).distinct().all()
        agent_ids = [id for (id,) in agent_ids]
        project_ids = session.query(Agent.project_id).filter(Agent.id.in_(agent_ids)).distinct().all()
        project_ids = [id for (id,) in project_ids]
        org_ids = session.query(Project.organisation_id).filter(Project.id.in_(project_ids)).distinct().all()
        org_ids = [id for (id,) in org_ids]

        # ``!= 1`` also covers the empty case, which previously surfaced as an IndexError.
        if len(org_ids) != 1 or org_ids[0] != organisation_id:
            raise Exception("one or more run IDs not found")
================================================
FILE: superagi/models/agent_execution_config.py
================================================
from sqlalchemy import Column, Integer, String, Text
from superagi.models.base_model import DBBaseModel
import ast
import json
from superagi.models.knowledges import Knowledges
from superagi.models.tool import Tool
from superagi.models.workflows.agent_workflow import AgentWorkflow
class AgentExecutionConfiguration(DBBaseModel):
    """
    Agent Execution related configurations like goals, instructions are stored here

    Attributes:
        id (int): The unique identifier of the agent execution config.
        agent_execution_id (int): The identifier of the associated agent execution.
        key (str): The key of the configuration setting.
        value (str): The value of the configuration setting (written as ``str(value)``).
    """

    __tablename__ = 'agent_execution_configs'

    id = Column(Integer, primary_key=True)
    agent_execution_id = Column(Integer)
    key = Column(String)
    value = Column(Text)

    def __repr__(self):
        """
        Returns a string representation of the AgentExecutionConfiguration object.

        Returns:
            str: String representation of the AgentExecutionConfig.
        """
        return f"AgentExecutionConfig(id={self.id}, agent_execution_id='{self.agent_execution_id}', " \
               f"key='{self.key}', value='{self.value}')"

    @classmethod
    def add_or_update_agent_execution_config(cls, session, execution, agent_execution_configs):
        """
        Inserts or updates one configuration row per key for an execution, then commits.

        Args:
            session: The database session object.
            execution (AgentExecution): The execution the configs belong to; only ``execution.id`` is read.
            agent_execution_configs (dict): Mapping of config key to value. Each value is stored as ``str(value)``.
        """
        for key, value in agent_execution_configs.items():
            stored_value = str(value)
            existing_config = (
                session.query(AgentExecutionConfiguration)
                .filter(
                    AgentExecutionConfiguration.agent_execution_id == execution.id,
                    AgentExecutionConfiguration.key == key
                )
                .first()
            )
            if existing_config:
                existing_config.value = stored_value
            else:
                session.add(AgentExecutionConfiguration(agent_execution_id=execution.id, key=key,
                                                        value=stored_value))
        session.commit()

    @classmethod
    def fetch_configuration(cls, session, execution_id):
        """
        Fetches and parses all configuration rows of an agent execution.

        Args:
            session: The database session object.
            execution_id (int): The ID of the agent execution.

        Returns:
            dict: Always contains "goal", "instruction" and "tools" (empty lists by default).
            Every stored key is set to ``eval_agent_config(key, value)``, which is None for keys
            other than "goal", "instruction" and "tools".
        """
        agent_configurations = session.query(AgentExecutionConfiguration).filter_by(
            agent_execution_id=execution_id).all()
        parsed_config = {
            "goal": [],
            "instruction": [],
            "tools": []
        }
        if not agent_configurations:
            return parsed_config
        for item in agent_configurations:
            parsed_config[item.key] = cls.eval_agent_config(item.key, item.value)
        return parsed_config

    @classmethod
    def eval_agent_config(cls, key, value):
        """
        Evaluates the value of an agent execution configuration setting based on its key.

        Args:
            key (str): The key of the execution configuration setting.
            value (str): The stored string value of the execution configuration setting.

        Returns:
            object: The evaluated value for "goal", "instruction" and "tools"; None for any other key.
        """
        if key in ("goal", "instruction", "tools"):
            # NOTE(security): eval() executes whatever expression is stored in the DB. Values written
            # by add_or_update_agent_execution_config are str() of Python literals, for which
            # ast.literal_eval would be a safe drop-in — confirm no other writer stores non-literals.
            return eval(value)
        return None

    @classmethod
    def build_agent_execution_config(cls, session, agent, results_agent, results_agent_execution, total_calls, total_tokens):
        """
        Builds the response dict for an execution by overlaying execution configs on agent configs.

        Args:
            session: The database session object.
            agent: The agent; ``name``, ``description`` and ``agent_workflow_id`` are read.
            results_agent: Iterable of agent-level config rows (``key``/``value``).
            results_agent_execution: Iterable of execution-level config rows (``key``/``value``).
            total_calls: Value stored under "calls".
            total_tokens: Value stored under "tokens".

        Returns:
            dict: Parsed configuration plus agent metadata and "knowledge_name".
        """
        results_agent_dict = {result.key: result.value for result in results_agent}
        results_agent_execution_dict = {result.key: result.value for result in results_agent_execution}
        # Execution-level values override agent-level ones, but only for keys the agent already
        # has and only when the execution value is not None.
        for key, value in results_agent_execution_dict.items():
            if key in results_agent_dict and value is not None:
                results_agent_dict[key] = value
        return cls._decorate_agent_config(session, agent, results_agent_dict, total_calls, total_tokens)

    @classmethod
    def build_scheduled_agent_execution_config(cls, session, agent, results_agent, total_calls, total_tokens):
        """
        Builds the response dict for a scheduled agent from its agent-level configs only.

        Args:
            session: The database session object.
            agent: The agent; ``name``, ``description`` and ``agent_workflow_id`` are read.
            results_agent: Iterable of agent-level config rows (``key``/``value``).
            total_calls: Value stored under "calls".
            total_tokens: Value stored under "tokens".

        Returns:
            dict: Parsed configuration plus agent metadata and "knowledge_name".
        """
        results_agent_dict = {result.key: result.value for result in results_agent}
        return cls._decorate_agent_config(session, agent, results_agent_dict, total_calls, total_tokens)

    @classmethod
    def _decorate_agent_config(cls, session, agent, results_agent_dict, total_calls, total_tokens):
        """
        Parses stored config strings in place and adds agent metadata, usage and knowledge name.

        Returns:
            dict: The same ``results_agent_dict`` object after mutation.
        """
        # NOTE(security): eval() on stored config strings — see eval_agent_config.
        if 'goal' in results_agent_dict:
            results_agent_dict['goal'] = eval(results_agent_dict['goal'])
        if "toolkits" in results_agent_dict:
            results_agent_dict["toolkits"] = list(ast.literal_eval(results_agent_dict["toolkits"]))
        if 'tools' in results_agent_dict:
            tool_ids = list(ast.literal_eval(results_agent_dict["tools"]))
            # Tool IDs are replaced by the matching Tool rows.
            results_agent_dict["tools"] = session.query(Tool).filter(Tool.id.in_(tool_ids)).all()
        if 'instruction' in results_agent_dict:
            results_agent_dict['instruction'] = eval(results_agent_dict['instruction'])
        if 'constraints' in results_agent_dict:
            results_agent_dict['constraints'] = eval(results_agent_dict['constraints'])

        results_agent_dict["name"] = agent.name
        agent_workflow = AgentWorkflow.find_by_id(session, agent.agent_workflow_id)
        results_agent_dict["agent_workflow"] = agent_workflow.name
        results_agent_dict["description"] = agent.description
        results_agent_dict["calls"] = total_calls
        results_agent_dict["tokens"] = total_tokens

        knowledge_name = ""
        if 'knowledge' in results_agent_dict and results_agent_dict['knowledge'] != 'None':
            # Config values come from a Text column (presumably always strings), so parse the ID
            # explicitly for the integer comparison and skip the lookup for non-integer values.
            knowledge_id = cls._parse_knowledge_id(results_agent_dict['knowledge'])
            if knowledge_id is not None:
                knowledge = session.query(Knowledges).filter(Knowledges.id == knowledge_id).first()
                knowledge_name = knowledge.name if knowledge is not None else ""
        results_agent_dict['knowledge_name'] = knowledge_name
        return results_agent_dict

    @staticmethod
    def _parse_knowledge_id(value):
        """Returns ``value`` as an int knowledge ID, or None if it is not an int or integer string."""
        if isinstance(value, int):
            return value
        try:
            return int(str(value).strip())
        except (TypeError, ValueError):
            return None

    @classmethod
    def fetch_value(cls, session, execution_id: int, key: str):
        """
        Fetches the value of a specific execution configuration setting for an agent.

        Args:
            session: The database session object.
            execution_id (int): The ID of the agent execution.
            key (str): The key of the execution configuration setting.

        Returns:
            AgentExecutionConfiguration: The execution configuration setting object if found, else None.
        """
        return session.query(AgentExecutionConfiguration).filter(
            AgentExecutionConfiguration.agent_execution_id == execution_id,
            AgentExecutionConfiguration.key == key).first()
================================================
FILE: superagi/models/agent_execution_feed.py
================================================
from sqlalchemy import Column, Integer, Text, String, asc
from sqlalchemy.orm import Session
from superagi.models.agent_execution import AgentExecution
from superagi.models.base_model import DBBaseModel
class AgentExecutionFeed(DBBaseModel):
    """
    Feed of the agent execution.

    Attributes:
        id (int): The unique identifier of the agent execution feed.
        agent_execution_id (int): The identifier of the associated agent execution.
        agent_id (int): The identifier of the associated agent.
        feed (str): The feed content (nullable).
        role (str): The role of the feed entry. Possible values: 'system', 'user', or 'assistant'.
        extra_info (str): Additional information related to the feed entry.
        feed_group_id (str): Group the entry belongs to; compared against the execution's
            ``current_feed_group_id`` when fetching feeds.
        error_message (str): Error text attached to the entry, if any.
    """

    __tablename__ = 'agent_execution_feeds'

    id = Column(Integer, primary_key=True)
    agent_execution_id = Column(Integer)
    agent_id = Column(Integer)
    feed = Column(Text)
    role = Column(String)
    extra_info = Column(String)
    feed_group_id = Column(String)
    error_message = Column(String)

    def __repr__(self):
        """
        Returns a string representation of the AgentExecutionFeed object.

        Returns:
            str: String representation of the AgentExecutionFeed.
        """
        return f"AgentExecutionFeed(id={self.id}, " \
               f"agent_execution_id={self.agent_execution_id}, " \
               f"feed='{self.feed}', role='{self.role}', extra_info='{self.extra_info}', feed_group_id='{self.feed_group_id}')"

    @classmethod
    def get_last_tool_response(cls, session: Session, agent_execution_id: int, tool_name: str = None):
        """
        Returns the newest "system" feed of an execution that starts with "Tool".

        Args:
            session (Session): The database session object.
            agent_execution_id (int): The execution whose feeds are searched.
            tool_name (str, optional): When given, the feed must start with "Tool <tool_name>".

        Returns:
            str: The matching feed text, or "" if no feed matches.
        """
        agent_execution_feeds = session.query(AgentExecutionFeed).filter(
            AgentExecutionFeed.agent_execution_id == agent_execution_id,
            AgentExecutionFeed.role == "system").order_by(AgentExecutionFeed.created_at.desc()).all()
        required_prefix = "Tool " + tool_name if tool_name else "Tool"
        for agent_execution_feed in agent_execution_feeds:
            feed = agent_execution_feed.feed
            # `feed` is a nullable Text column; skip NULL rows instead of failing on None.startswith.
            if feed and feed.startswith(required_prefix):
                return feed
        return ""

    @classmethod
    def fetch_agent_execution_feeds(cls, session, agent_execution_id: int):
        """
        Returns (role, feed, id) rows of the execution's current feed group, oldest first.

        Args:
            session: The database session object.
            agent_execution_id (int): The execution whose feeds are returned.

        Returns:
            list: Rows of the current feed group; for the "DEFAULT" group the first two rows
            are dropped.
        """
        # NOTE(review): raises AttributeError if the execution does not exist — callers are
        # presumably expected to pass a valid ID.
        agent_execution = AgentExecution.find_by_id(session, agent_execution_id)
        agent_feeds = session.query(AgentExecutionFeed.role, AgentExecutionFeed.feed, AgentExecutionFeed.id) \
            .filter(AgentExecutionFeed.agent_execution_id == agent_execution_id,
                    AgentExecutionFeed.feed_group_id == agent_execution.current_feed_group_id) \
            .order_by(asc(AgentExecutionFeed.created_at)) \
            .all()
        # return entire feed if it is not default feed. Default feed has prompt in the first 2 entries.
        if agent_execution.current_feed_group_id != "DEFAULT":
            return agent_feeds
        return agent_feeds[2:]
================================================
FILE: superagi/models/agent_execution_permission.py
================================================
from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey
from sqlalchemy.orm import relationship
from superagi.models.base_model import DBBaseModel
from superagi.models.agent_execution import AgentExecution
class AgentExecutionPermission(DBBaseModel):
    """
    A per-step permission request of an agent execution, approved or rejected by the user.

    Attributes:
        id (Integer): Primary key of the permission record.
        agent_execution_id (Integer): The agent execution this permission belongs to.
        agent_id (Integer): The agent this permission belongs to.
        status (String): APPROVED, REJECTED, or PENDING.
        tool_name (String): Name of the tool or service that requires the permission.
        user_feedback (Text): Feedback given by the user on this permission (shown as
            ``response`` in ``__repr__``).
        question (Text): Question text attached to the permission request.
        assistant_reply (Text): The reply or message sent back to the user by the assistant.

    Methods:
        __repr__: Returns a string representation of the AgentExecutionPermission instance.
    """

    __tablename__ = 'agent_execution_permissions'

    id = Column(Integer, primary_key=True)
    agent_execution_id = Column(Integer)
    agent_id = Column(Integer)
    status = Column(String)
    tool_name = Column(String)
    user_feedback = Column(Text)
    question = Column(Text)
    assistant_reply = Column(Text)

    def __repr__(self):
        """Returns a string representation of the AgentExecutionPermission instance."""
        labelled_values = (
            ("id", self.id),
            ("agent_execution_id", self.agent_execution_id),
            ("agent_id", self.agent_id),
            ("status", self.status),
            ("tool_name", self.tool_name),
            ("question", self.question),
            ("response", self.user_feedback),
        )
        rendered = ", ".join(f"{label}={value}" for label, value in labelled_values)
        return f"AgentExecutionPermission({rendered})"
================================================
FILE: superagi/models/agent_schedule.py
================================================
from sqlalchemy import Column, Integer, String, Date, DateTime
from superagi.models.base_model import DBBaseModel
from superagi.controllers.types.agent_schedule import AgentScheduleInput
class AgentSchedule(DBBaseModel):
    """
    Represents an Agent Scheduler record in the database.

    Attributes:
        id (Integer): The primary key of the agent scheduler record.
        agent_id (Integer): The ID of the agent being scheduled.
        start_time (DateTime): The date and time from which the agent is scheduled.
        next_scheduled_time (DateTime): The next time a run is due; initialised to ``start_time``.
        recurrence_interval (String): Stores "none" if not recurring,
            or a time interval like '2 Weeks', '1 Month', '2 Minutes' based on input.
        expiry_date (DateTime): The date and time when the agent is scheduled to stop runs.
        expiry_runs (Integer): The number of runs before the agent expires.
        current_runs (Integer): Number of runs executed in that schedule.
        status: state in which the schedule is, "SCHEDULED" or "STOPPED" or "COMPLETED" or "TERMINATED"

    Methods:
        __repr__: Returns a string representation of the AgentSchedule instance.
    """

    __tablename__ = 'agent_schedule'

    id = Column(Integer, primary_key=True)
    agent_id = Column(Integer)
    start_time = Column(DateTime)
    next_scheduled_time = Column(DateTime)
    recurrence_interval = Column(String)
    expiry_date = Column(DateTime)
    expiry_runs = Column(Integer)
    current_runs = Column(Integer)
    status = Column(String)

    def __repr__(self):
        """
        Returns a string representation of the AgentSchedule instance.
        """
        return f"AgentSchedule(id={self.id}, " \
               f"agent_id={self.agent_id}, " \
               f"start_time={self.start_time}, " \
               f"next_scheduled_time={self.next_scheduled_time}, " \
               f"recurrence_interval={self.recurrence_interval}, " \
               f"expiry_date={self.expiry_date}, " \
               f"expiry_runs={self.expiry_runs}, " \
               f"current_runs={self.current_runs}, " \
               f"status={self.status})"

    @classmethod
    def save_schedule_from_config(cls, session, db_agent, schedule_config: AgentScheduleInput):
        """
        Creates and commits a new "SCHEDULED" schedule for an agent.

        Args:
            session: The database session object.
            db_agent: The agent to schedule; only ``db_agent.id`` is read.
            schedule_config (AgentScheduleInput): Start time, recurrence and expiry settings.

        Returns:
            AgentSchedule: The persisted schedule, with ``current_runs`` set to 0 and
            ``next_scheduled_time`` set to the start time.
        """
        agent_schedule = AgentSchedule(
            agent_id=db_agent.id,
            start_time=schedule_config.start_time,
            next_scheduled_time=schedule_config.start_time,
            recurrence_interval=schedule_config.recurrence_interval,
            expiry_date=schedule_config.expiry_date,
            expiry_runs=schedule_config.expiry_runs,
            current_runs=0,
            status="SCHEDULED"
        )
        session.add(agent_schedule)
        session.commit()
        return agent_schedule

    @classmethod
    def find_by_agent_id(cls, session, agent_id: int):
        """
        Returns the first schedule of an agent, or None if it has none.

        Args:
            session: The database session object.
            agent_id (int): The agent whose schedule is looked up.
        """
        return session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id).first()
================================================
FILE: superagi/models/agent_template.py
================================================
import json
import requests
from sqlalchemy import Column, Integer, String, Text
from superagi.lib.logger import logger
from superagi.models.agent_template_config import AgentTemplateConfig
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.base_model import DBBaseModel
from superagi.models.workflows.iteration_workflow import IterationWorkflow
# Base URL of the marketplace API; AgentTemplate appends endpoint paths such as
# "agent_templates/marketplace/list" to it.
marketplace_url = "https://app.superagi.com/api/"
# Alternative base URL, presumably for a locally running marketplace during development.
# marketplace_url = "http://localhost:8001/"
class AgentTemplate(DBBaseModel):
    """
    Preconfigured agent templates that can be used to create agents.

    Attributes:
        id (int): The unique identifier of the agent template.
        organisation_id (int): The organization ID of the user or -1 if the template is public.
        agent_workflow_id (int): The identifier of the workflow that the agent will use.
        name (str): The name of the agent template.
        description (str): The description of the agent template.
        marketplace_template_id (int): The ID of the template in the marketplace.
    """

    __tablename__ = 'agent_templates'

    id = Column(Integer, primary_key=True)
    organisation_id = Column(Integer)
    agent_workflow_id = Column(Integer)
    name = Column(String)
    description = Column(Text)
    marketplace_template_id = Column(Integer)

    def __repr__(self):
        """
        Returns a string representation of the AgentTemplate object.

        Returns:
            str: String representation of the AgentTemplate.
        """
        return f"AgentTemplate(id={self.id}, name='{self.name}', " \
               f"description='{self.description}')"

    def to_dict(self):
        """
        Converts the AgentTemplate object to a dictionary.

        Returns:
            dict: Dictionary with the ``id``, ``name`` and ``description`` of the AgentTemplate.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description
        }

    def to_json(self):
        """
        Converts the AgentTemplate object to a JSON string.

        Returns:
            str: JSON string representation of the AgentTemplate.
        """
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """
        Creates an AgentTemplate object from a JSON string.

        Args:
            json_data (str): JSON string representing the AgentTemplate.

        Returns:
            AgentTemplate: AgentTemplate object created from the JSON string.
        """
        data = json.loads(json_data)
        return cls(
            id=data['id'],
            name=data['name'],
            description=data['description']
        )

    @classmethod
    def main_keys(cls):
        """
        Returns the main keys for fetching agent templates.

        Returns:
            list: List of main keys.
        """
        keys_to_fetch = ["goal", "instruction", "constraints", "tools", "exit", "iteration_interval", "model",
                         "permission_type", "LTM_DB", "max_iterations", "knowledge"]
        return keys_to_fetch

    @classmethod
    def fetch_marketplace_list(cls, search_str, page):
        """
        Fetches a list of agent templates from the marketplace.

        Args:
            search_str (str): The search string to filter agent templates.
            page (int): The page number of the result set.

        Returns:
            list: The parsed JSON body on HTTP 200; [] on any other status or on a request error.
        """
        headers = {'Content-Type': 'application/json'}
        try:
            # Let requests build and URL-encode the query string; concatenating search_str into the
            # URL broke searches containing spaces, '&', '#', '+', etc.
            response = requests.get(
                marketplace_url + "agent_templates/marketplace/list",
                params={"search": search_str, "page": page},
                headers=headers, timeout=10)
        except requests.RequestException as error:
            # Same fallback as a non-200 response: the marketplace is optional.
            logger.error(f"Failed to fetch marketplace agent templates: {error}")
            return []
        if response.status_code == 200:
            return response.json()
        return []

    @classmethod
    def fetch_marketplace_detail(cls, agent_template_id):
        """
        Fetches the details of an agent template from the marketplace.

        Args:
            agent_template_id (int): The ID of the agent template.

        Returns:
            dict: The parsed JSON body on HTTP 200; {} on any other status or on a request error.
        """
        headers = {'Content-Type': 'application/json'}
        try:
            response = requests.get(
                marketplace_url + "agent_templates/marketplace/template_details/" + str(agent_template_id),
                headers=headers, timeout=10)
        except requests.RequestException as error:
            logger.error(f"Failed to fetch marketplace agent template {agent_template_id}: {error}")
            return {}
        if response.status_code == 200:
            return response.json()
        return {}

    @classmethod
    def clone_agent_template_from_marketplace(cls, db, organisation_id: int, agent_template_id: int):
        """
        Clones an agent template from the marketplace and saves it in the database.

        Args:
            db: The database object (``db.session`` is used).
            organisation_id (int): The organization ID.
            agent_template_id (int): The ID of the agent template in the marketplace.

        Returns:
            AgentTemplate: The cloned agent template object.
        """
        agent_template = AgentTemplate.fetch_marketplace_detail(agent_template_id)
        agent_workflow = db.session.query(AgentWorkflow).filter(
            AgentWorkflow.name == agent_template["agent_workflow_name"]).first()
        # keeping it backward compatible: older marketplace templates reference legacy workflow
        # names that are mapped to current workflows below.
        logger.info("agent_workflow:" + str(agent_template["agent_workflow_name"]))
        if not agent_workflow:
            workflow_id = AgentTemplate.fetch_iteration_agent_template_mapping(db.session, agent_template["agent_workflow_name"])
            agent_workflow = db.session.query(AgentWorkflow).filter(AgentWorkflow.id == workflow_id).first()

        template = AgentTemplate(organisation_id=organisation_id, agent_workflow_id=agent_workflow.id,
                                 name=agent_template["name"], description=agent_template["description"],
                                 marketplace_template_id=agent_template["id"])
        db.session.add(template)
        db.session.commit()
        db.session.flush()

        agent_configurations = []
        for key, value in agent_template["configs"].items():
            # Converting tool names to ids and saving it in agent configuration
            agent_configurations.append(
                AgentTemplateConfig(agent_template_id=template.id, key=key, value=str(value["value"])))

        db.session.add_all(agent_configurations)
        db.session.commit()
        db.session.flush()
        return template

    @classmethod
    def fetch_iteration_agent_template_mapping(cls, session, name):
        """
        Maps a legacy marketplace workflow name to the ID of its current workflow.

        Args:
            session: The database session object.
            name (str): The legacy workflow name.

        Returns:
            int: The mapped workflow's ID, or None if ``name`` is not a known legacy name.
        """
        legacy_to_current_workflow = {
            "Fixed Task Queue": "Fixed Task Workflow",
            "Maintain Task Queue": "Dynamic Task Workflow",
            "Don't Maintain Task Queue": "Goal Based Workflow",
            "Goal Based Agent": "Goal Based Workflow",
        }
        workflow_name = legacy_to_current_workflow.get(name)
        if workflow_name is None:
            return None
        agent_workflow = AgentWorkflow.find_by_name(session, workflow_name)
        return agent_workflow.id

    @classmethod
    def eval_agent_config(cls, key, value):
        """
        Evaluates the value of an agent configuration key.

        Args:
            key (str): The key of the agent configuration.
            value (str): The value of the agent configuration.

        Returns:
            object: The raw string for plain-text keys, an int (or None) for numeric keys, the
            evaluated object for "goal"/"constraints"/"instruction", a list of str for "tools",
            and None for any other key.
        """
        if key in ["name", "description", "exit", "model", "permission_type", "LTM_DB"]:
            return value
        elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval", "knowledge"]:
            if value is not None and value != 'None':
                return int(value)
            else:
                return None
        elif key == "goal" or key == "constraints" or key == "instruction":
            # NOTE(security): eval() executes arbitrary expressions from stored/marketplace data;
            # ast.literal_eval is a safer replacement if these values are always Python literals.
            return eval(value)
        elif key == "tools":
            return [str(x) for x in eval(value)]
================================================
FILE: superagi/models/agent_template_config.py
================================================
import json
from sqlalchemy import Column, Integer, String, Text
from superagi.models.base_model import DBBaseModel
class AgentTemplateConfig(DBBaseModel):
    """
    Key/value configuration of an agent template (goals, instructions, constraints, tools, ...).

    Attributes:
        id (int): Primary key of the template config row.
        agent_template_id (int): The agent template this setting belongs to.
        key (str): Name of the configuration setting.
        value (str): Stored value of the configuration setting.
    """

    __tablename__ = 'agent_template_configs'

    id = Column(Integer, primary_key=True)
    agent_template_id = Column(Integer)
    key = Column(String)
    value = Column(Text)

    def __repr__(self):
        """
        Builds a readable representation of this template config.

        Returns:
            str: Text of the form ``AgentTemplateConfig(id=..., agent_template_id='...', key='...', value='...')``.
        """
        return "AgentTemplateConfig(id={}, agent_template_id='{}', key='{}', value='{}')".format(
            self.id, self.agent_template_id, self.key, self.value)

    def to_dict(self):
        """
        Serializes the four config fields into a dictionary.

        Returns:
            dict: Mapping with keys ``id``, ``agent_template_id``, ``key`` and ``value``.
        """
        field_names = ('id', 'agent_template_id', 'key', 'value')
        return {field_name: getattr(self, field_name) for field_name in field_names}

    def to_json(self):
        """
        Serializes the config into a JSON string.

        Returns:
            str: JSON text of ``to_dict()``.
        """
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    @classmethod
    def from_json(cls, json_data):
        """
        Rebuilds a config object from JSON text.

        Args:
            json_data (str): JSON object containing ``id``, ``agent_template_id``, ``key`` and ``value``.

        Returns:
            AgentTemplateConfig: New, unsaved instance populated from the JSON.

        Raises:
            KeyError: If one of the four fields is missing.
        """
        payload = json.loads(json_data)
        field_names = ('id', 'agent_template_id', 'key', 'value')
        return cls(**{field_name: payload[field_name] for field_name in field_names})
================================================
FILE: superagi/models/api_key.py
================================================
from sqlalchemy import Column, Integer, String, Boolean
from superagi.models.base_model import DBBaseModel
from sqlalchemy import or_
class ApiKey(DBBaseModel):
    """
    API key belonging to an organisation.

    Attributes:
        id (int): Primary key.
        org_id (int): ID of the owning organisation.
        name (str): Display name of the key.
        key (str): The key string itself.
        is_expired (bool): Set to True by ``delete_by_id``; False or NULL means the key is active.

    Methods:
        get_by_org_id: Active keys of an organisation.
        get_by_id: Active key with the given ID.
        delete_by_id: Marks a key as expired instead of deleting the row.
        update_api_key: Renames a key.
    """

    __tablename__ = 'api_keys'

    id = Column(Integer, primary_key=True)
    org_id = Column(Integer)
    name = Column(String)
    key = Column(String)
    is_expired = Column(Boolean)

    @staticmethod
    def _not_expired_clause():
        """SQL condition matching keys whose ``is_expired`` is False or has never been set."""
        return or_(ApiKey.is_expired == False, ApiKey.is_expired == None)  # noqa: E711,E712 (SQL expressions)

    @staticmethod
    def _find_any(session, id: int):
        """Returns the key with ``id`` regardless of expiry, or None."""
        return session.query(ApiKey).filter(ApiKey.id == id).first()

    @classmethod
    def get_by_org_id(cls, session, org_id: int):
        """Returns all active keys of organisation ``org_id``."""
        active_keys = session.query(ApiKey).filter(
            ApiKey.org_id == org_id, ApiKey._not_expired_clause()).all()
        return active_keys

    @classmethod
    def get_by_id(cls, session, id: int):
        """Returns the active key with ``id``, or None if it is missing or expired."""
        active_key = session.query(ApiKey).filter(
            ApiKey.id == id, ApiKey._not_expired_clause()).first()
        return active_key

    @classmethod
    def delete_by_id(cls, session, id: int):
        """Marks key ``id`` as expired and commits; the row itself is kept."""
        target = ApiKey._find_any(session, id)
        target.is_expired = True
        session.commit()
        session.flush()

    @classmethod
    def update_api_key(cls, session, id: int, name: str):
        """Sets the display name of key ``id`` and commits."""
        target = ApiKey._find_any(session, id)
        target.name = name
        session.commit()
        session.flush()
================================================
FILE: superagi/models/base_model.py
================================================
import json
from sqlalchemy import Column, DateTime, INTEGER
from sqlalchemy.orm import declarative_base
from datetime import datetime
# Shared SQLAlchemy declarative base; DBBaseModel derives from it.
Base = declarative_base()
class DBBaseModel(Base):
    """
    DBBaseModel is an abstract base class for all SQLAlchemy ORM models,
    providing common columns and functionality.

    Attributes:
        created_at: Datetime column to store the timestamp about when a row is created.
        updated_at: Datetime column to store the timestamp about when a row is updated.

    Methods:
        to_dict: Converts the current object to a dictionary.
        to_json: Converts the current object to a JSON string.
        from_json: Creates a new object of the class using the provided JSON data.
        __repr__: Returns a string representation of the current object.
    """

    __abstract__ = True
    # id = Column(INTEGER,primary_key=True,autoincrement=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def to_dict(self):
        """
        Converts the current SQLAlchemy ORM object to a dictionary representation.

        Returns:
            A dictionary mapping column names to their corresponding values.
        """
        return {column.name: getattr(self, column.name) for column in self.__table__.columns}

    def to_json(self):
        """
        Converts the current SQLAlchemy ORM object to a JSON string representation.

        Date and datetime values (e.g. ``created_at``/``updated_at``) are written as ISO-8601
        strings; plain ``json.dumps`` cannot serialize them and would raise.

        Returns:
            A JSON string representing the object with column names as keys and their corresponding values.

        Raises:
            TypeError: If a column value is neither JSON-native nor a date/datetime.
        """
        return json.dumps(self.to_dict(), default=self._json_default)

    @staticmethod
    def _json_default(value):
        """``json.dumps`` fallback: ISO-8601 for dates/datetimes, TypeError for anything else."""
        from datetime import date  # datetime is a subclass of date, so both are handled
        if isinstance(value, date):
            return value.isoformat()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    @classmethod
    def from_json(cls, json_data):
        """
        Creates a new SQLAlchemy ORM object of the class using the provided JSON data.

        Args:
            json_data (str): A JSON string representing the object with column names as keys and their
                corresponding values.

        Returns:
            A new SQLAlchemy ORM object of the class.
        """
        # NOTE(review): values are passed through as decoded by json.loads; ISO date strings
        # produced by to_json are not converted back to datetime objects here.
        return cls(**json.loads(json_data))

    def __repr__(self):
        """
        Returns a string representation of the current SQLAlchemy ORM object.

        Returns:
            A string of the form "<ClassName> (<column dict>)".
        """
        return f"{self.__class__.__name__} ({self.to_dict()})"
================================================
FILE: superagi/models/budget.py
================================================
from sqlalchemy import Column, Integer, String, Float
from superagi.models.base_model import DBBaseModel
class Budget(DBBaseModel):
    """
    A budget amount together with its cycle.

    Attributes:
        id (Integer): Primary key of the budget row.
        budget (Float): The budget amount.
        cycle (String): The cycle the budget applies to.
    """

    __tablename__ = 'budgets'

    id = Column(Integer, primary_key=True)
    budget = Column(Float)
    cycle = Column(String)

    def __repr__(self):
        """
        Builds a readable representation of this budget.

        Returns:
            str: Text of the form ``Budget(id=..., budget=..., cycle='...')``.
        """
        template = "Budget(id={}, budget={}, cycle='{}')"
        return template.format(self.id, self.budget, self.cycle)
================================================
FILE: superagi/models/call_logs.py
================================================
from sqlalchemy import Column, Integer, String
from superagi.models.base_model import DBBaseModel
class CallLogs(DBBaseModel):
    """
    Represents a call log record in the ``call_logs`` table.

    Attributes:
        id (Integer): Primary key of the log row.
        agent_execution_name (String): The name of the agent execution.
        agent_id (Integer): ID of the agent the call is logged for (per the column name — confirm).
        tokens_consumed (Integer): The number of tokens for a call.
        tool_used (String): The tool used for the call.
        model (String): The model used for the call (nullable).
        org_id (Integer): The ID of the organisation.
    """

    __tablename__ = 'call_logs'

    id = Column(Integer, primary_key=True)
    agent_execution_name = Column(String, nullable=False)
    agent_id = Column(Integer, nullable=False)
    tokens_consumed = Column(Integer, nullable=False)
    tool_used = Column(String, nullable=False)
    model = Column(String, nullable=True)
    org_id = Column(Integer, nullable=False)

    def __repr__(self):
        """Returns a string representation of the CallLogs instance."""
        pieces = [
            f"id={self.id}",
            f"agent_execution_name={self.agent_execution_name}",
            f"agent_id={self.agent_id}",
            f"tokens_consumed={self.tokens_consumed}",
            f"tool_used={self.tool_used}",
            f"model={self.model}",
            f"org_id={self.org_id}",
        ]
        return "CallLogs(" + ", ".join(pieces) + ")"
================================================
FILE: superagi/models/configuration.py
================================================
from fastapi import HTTPException
from sqlalchemy import Column, Integer, String,Text
from superagi.helper.encyption_helper import decrypt_data
from superagi.models.base_model import DBBaseModel
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.models_config import ModelsConfig
from superagi.models.models import Models
from superagi.helper.encyption_helper import decrypt_data
class Configuration(DBBaseModel):
"""
General org level configurations are stored here
Attributes:
id (Integer): The primary key of the configuration.
organisation_id (Integer): The ID of the organization associated with the configuration.
key (String): The configuration key.
value (Text): The configuration value.
"""
__tablename__ = 'configurations'
id = Column(Integer, primary_key=True, autoincrement=True)
organisation_id = Column(Integer)
key = Column(String)
value = Column(Text)
def __repr__(self):
"""
Returns a string representation of the Configuration object.
Returns:
str: String representation of the Configuration object.
"""
return f"Config(id={self.id}, organisation_id={self.organisation_id}, key={self.key}, value={self.value})"
@classmethod
def fetch_configuration(cls, session, organisation_id: int, key: str, default_value=None) -> str:
"""
Fetches the configuration of an agent.
Args:
session: The database session object.
organisation_id (int): The ID of the organisation.
key (str): The key of the configuration.
default_value (str): The default value of the configuration.
Returns:
dict: Parsed configuration.
"""
configuration = session.query(Configuration).filter_by(organisation_id=organisation_id, key=key).first()
if key == "model_api_key":
return decrypt_data(configuration.value) if configuration else default_value
else:
return configuration.value if configuration else default_value
@classmethod
def fetch_configurations(cls, session, organisation_id: int, key: str, model: str, default_value=None) -> str:
"""
Fetches the configuration of an agent.
Args:
session: The database session object.
organisation_id (int): The ID of the organisation.
key (str): The key of the configuration.
default_value (str): The default value of the configuration.
Returns:
dict: Parsed configuration.
"""
model_provider = session.query(Models).filter(Models.org_id == organisation_id, Models.model_name == model).first()
if not model_provider:
raise HTTPException(status_code=404, detail="Model provider not found")
configuration = session.query(ModelsConfig.provider, ModelsConfig.api_key).filter(ModelsConfig.org_id == organisation_id, ModelsConfig.id == model_provider.model_provider_id).first()
if key == "model_api_key":
return decrypt_data(configuration.api_key) if configuration else default_value
else:
return configuration.provider if configuration else default_value
@classmethod
def fetch_value_by_agent_id(cls, session, agent_id: int, key: str):
    """
    Fetch an organisation configuration value starting from an agent ID.

    The agent is resolved to its project, the project to its organisation, and the
    configuration row is then looked up by (organisation id, key).

    Args:
        session: The database session object.
        agent_id (int): The ID of the agent.
        key (str): The key of the configuration.

    Returns:
        The raw stored value (not decrypted), or None when no row matches.

    Raises:
        HTTPException: 404 if the agent, its project, or its organisation is missing.
    """
    from superagi.models.agent import Agent

    agent = session.query(Agent).filter(Agent.id == agent_id).first()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")

    project = session.query(Project).filter(Project.id == agent.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    organisation = session.query(Organisation).filter(Organisation.id == project.organisation_id).first()
    if not organisation:
        raise HTTPException(status_code=404, detail="Organisation not found")

    config = session.query(Configuration).filter(
        Configuration.organisation_id == organisation.id,
        Configuration.key == key,
    ).first()
    return config.value if config else None
================================================
FILE: superagi/models/db.py
================================================
from sqlalchemy import create_engine
from superagi.config.config import get_config
from urllib.parse import urlparse
from superagi.lib.logger import logger
# Process-wide engine, created lazily by connect_db() and reused afterwards.
engine = None


def _redact_db_url(db_url):
    """
    Return ``db_url`` with any password in its network location replaced by ``***``.

    Used only when logging, so database credentials never reach application logs.
    URLs without a ``user:password@`` part are returned unchanged.
    """
    parsed = urlparse(db_url)
    # netloc looks like "user:password@host[:port]"; split on the LAST '@'.
    userinfo, sep, host_part = parsed.netloc.rpartition('@')
    if not sep or ':' not in userinfo:
        return db_url
    username = userinfo.split(':', 1)[0]
    return parsed._replace(netloc=f"{username}:***@{host_part}").geturl()


def connect_db():
    """
    Connects to the PostgreSQL database using SQLAlchemy, creating the engine only once.

    The URL comes from ``DB_URL`` when configured; only its scheme, network location
    and path are kept, so query string and fragment are dropped. Otherwise it is
    assembled from ``DB_HOST`` (default ``super__postgres``), ``DB_USERNAME``,
    ``DB_PASSWORD`` and ``DB_NAME``.

    Returns:
        engine: The SQLAlchemy engine object representing the database connection.
        It is cached and returned even if the initial test connection fails; that
        failure is only logged.
    """
    global engine
    if engine is not None:
        return engine

    # Create the connection URL
    db_host = get_config('DB_HOST', 'super__postgres')
    db_username = get_config('DB_USERNAME')
    db_password = get_config('DB_PASSWORD')
    db_name = get_config('DB_NAME')
    db_url = get_config('DB_URL', None)

    if db_url is None:
        if db_username is None:
            db_url = f'postgresql://{db_host}/{db_name}'
        elif db_password is None:
            # Without this branch the literal string "None" would become the password.
            db_url = f'postgresql://{db_username}@{db_host}/{db_name}'
        else:
            db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}'
    else:
        parsed_url = urlparse(db_url)
        db_url = parsed_url.scheme + "://" + parsed_url.netloc + parsed_url.path

    # Create the SQLAlchemy engine
    engine = create_engine(db_url,
                           pool_size=20,  # Maximum number of database connections in the pool
                           max_overflow=50,  # Maximum number of connections that can be created beyond the pool_size
                           pool_timeout=30,  # Timeout value in seconds for acquiring a connection from the pool
                           pool_recycle=1800,  # Recycle connections after this number of seconds (optional)
                           pool_pre_ping=False,  # Connection health checks on checkout are disabled
                           )

    # Test the connection
    try:
        connection = engine.connect()
        # Never log the password embedded in the URL.
        logger.info("Connected to the database! @ " + _redact_db_url(db_url))
        connection.close()
    except Exception as e:
        logger.error(f"Unable to connect to the database:{e}")
    return engine
================================================
FILE: superagi/models/events.py
================================================
from sqlalchemy import Column, Integer, String, DateTime, Sequence
from sqlalchemy.dialects.postgresql import JSONB
from superagi.models.base_model import DBBaseModel
from sqlalchemy.ext.declarative import declarative_base
# Module-level declarative base; Event below derives from DBBaseModel, not from this Base.
Base = declarative_base()


class Event(DBBaseModel):
    """
    ORM model for a single recorded event.

    Attributes:
        id (Integer): Primary key of the event row.
        event_name (String): Name of the event (required).
        event_value (Integer): Integer value attached to the event (required).
        event_property (JSONB): Optional JSON payload with extra event attributes.
        agent_id (Integer): Optional ID of the related agent.
        org_id (Integer): Optional ID of the related organisation.
    """

    __tablename__ = 'events'

    id = Column(Integer, primary_key=True)
    event_name = Column(String, nullable=False)
    event_value = Column(Integer, nullable=False)
    event_property = Column(JSONB, nullable=True)
    agent_id = Column(Integer, nullable=True)
    org_id = Column(Integer, nullable=True)

    def __repr__(self):
        """Debug string with every column except ``event_property``."""
        parts = [
            f"id={self.id}",
            f"event_name={self.event_name}",
            f"event_value={self.event_value}",
            f"agent_id={self.agent_id}",
            f"org_id={self.org_id}",
        ]
        return "Event(" + ", ".join(parts) + ")"
================================================
FILE: superagi/models/knowledge_configs.py
================================================
from sqlalchemy import Column, Integer, Text, String
import requests
from superagi.models.base_model import DBBaseModel
marketplace_url = "https://app.superagi.com/api"
# marketplace_url = "http://localhost:8001"


class KnowledgeConfigs(DBBaseModel):
    """
    Knowledge related configurations such as model, data_type, tokenizer, chunk_size, chunk_overlap, text_splitter, etc. are stored here.

    Attributes:
        id (int): The unique identifier of the knowledge configuration.
        knowledge_id (int): The identifier of the associated knowledge.
        key (str): The key of the configuration setting.
        value (str): The value of the configuration setting.
    """

    __tablename__ = 'knowledge_configs'

    id = Column(Integer, primary_key=True, autoincrement=True)
    knowledge_id = Column(Integer)
    key = Column(String)
    value = Column(Text)

    def __repr__(self):
        """
        Returns a string representation of the Knowledge Configuration object.

        Returns:
            str: String representation of the Knowledge Configuration.
        """
        return f"KnowledgeConfiguration(id={self.id}, knowledge_id={self.knowledge_id}, key={self.key}, value={self.value})"

    @classmethod
    def fetch_knowledge_config_details_marketplace(cls, knowledge_id: int):
        """
        Fetch a knowledge's configuration from the marketplace API.

        Args:
            knowledge_id (int): ID of the knowledge on the marketplace.

        Returns:
            dict: key -> value mapping on HTTP 200. On any other status an empty list
            is returned; kept as-is for backward compatibility with existing callers.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + f"/knowledge_configs/marketplace/details/{str(knowledge_id)}",
            headers=headers, timeout=10)
        if response.status_code != 200:
            return []
        return {config["key"]: config["value"] for config in response.json()}

    @classmethod
    def add_update_knowledge_config(cls, session, knowledge_id, knowledge_configs):
        """
        Insert one KnowledgeConfigs row per entry of ``knowledge_configs`` and commit.

        Rows are always inserted (existing rows are not updated). All rows are
        committed together, so either every entry is persisted or none is.

        Args:
            session: The database session.
            knowledge_id (int): ID of the knowledge the configs belong to.
            knowledge_configs (dict): Mapping of config key -> value.
        """
        configs = [KnowledgeConfigs(knowledge_id=knowledge_id, key=key, value=value)
                   for key, value in knowledge_configs.items()]
        if not configs:
            # Nothing to store; avoid committing unrelated pending session state.
            return
        session.add_all(configs)
        session.commit()

    @classmethod
    def get_knowledge_config_from_knowledge_id(cls, session, knowledge_id):
        """
        Return all stored configs of a knowledge as a key -> value dict.

        Args:
            session: The database session.
            knowledge_id (int): ID of the knowledge.

        Returns:
            dict: Mapping of config key -> value (empty if none are stored).
        """
        knowledge_configs = session.query(KnowledgeConfigs).filter(KnowledgeConfigs.knowledge_id == knowledge_id).all()
        return {knowledge_config.key: knowledge_config.value for knowledge_config in knowledge_configs}

    @classmethod
    def delete_knowledge_config(cls, session, knowledge_id):
        """
        Delete every config row of a knowledge and commit.

        Args:
            session: The database session.
            knowledge_id (int): ID of the knowledge.
        """
        session.query(KnowledgeConfigs).filter(KnowledgeConfigs.knowledge_id == knowledge_id).delete()
        session.commit()
================================================
FILE: superagi/models/knowledges.py
================================================
from __future__ import annotations
from sqlalchemy import Column, Integer, String
import requests
# from superagi.models import AgentConfiguration
from superagi.models.base_model import DBBaseModel
marketplace_url = "https://app.superagi.com/api"
# marketplace_url = "http://localhost:8001"


class Knowledges(DBBaseModel):
    """
    Represents a knowledge entity.

    Attributes:
        id (int): The unique identifier of the knowledge.
        name (str): The name of the knowledge.
        description (str): The description of the knowledge.
        vector_db_index_id (int): The index associated with the knowledge.
        organisation_id (int): The identifier of the associated organisation.
        contributed_by (str): The ``contributed_by`` value supplied when the knowledge was created.
    """

    __tablename__ = 'knowledges'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)
    description = Column(String)
    vector_db_index_id = Column(Integer)
    organisation_id = Column(Integer)
    contributed_by = Column(String)

    def __repr__(self):
        """
        Returns a string representation of the Knowledge object.

        Returns:
            str: String representation of the Knowledge.
        """
        return f"Knowledge(id={self.id}, name='{self.name}', description='{self.description}', " \
               f"vector_db_index_id={self.vector_db_index_id}, organisation_id={self.organisation_id}, " \
               f"contributed_by={self.contributed_by})"

    @classmethod
    def fetch_marketplace_list(cls, page):
        """
        Fetch one page of knowledges from the marketplace API.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty list.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + f"/knowledges/marketplace/list/{str(page)}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            return []

    @classmethod
    def get_knowledge_install_details(cls, session, marketplace_knowledges, organisation):
        """
        Mark each marketplace knowledge dict with ``is_installed`` for the organisation.

        A knowledge counts as installed when the organisation has a knowledge with the
        same name. The dicts are modified in place and the same list is returned.
        """
        # Build the set of installed names once instead of once per marketplace item.
        installed_names = {
            row.name for row in
            session.query(Knowledges.name).filter(Knowledges.organisation_id == organisation.id).all()
        }
        for knowledge in marketplace_knowledges:
            knowledge["is_installed"] = knowledge["name"] in installed_names
        return marketplace_knowledges

    @classmethod
    def get_organisation_knowledges(cls, session, organisation):
        """Return ``id``/``name``/``contributed_by`` dicts for every knowledge of the organisation."""
        knowledges = session.query(Knowledges).filter(Knowledges.organisation_id == organisation.id).all()
        return [
            {
                "id": knowledge.id,
                "name": knowledge.name,
                "contributed_by": knowledge.contributed_by
            }
            for knowledge in knowledges
        ]

    @classmethod
    def fetch_knowledge_details_marketplace(cls, knowledge_name):
        """
        Fetch a knowledge's details from the marketplace API by name.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty list.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + f"/knowledges/marketplace/details/{knowledge_name}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            return []

    @classmethod
    def get_knowledge_from_id(cls, session, knowledge_id):
        """Return the knowledge with the given ID, or None."""
        return session.query(Knowledges).filter(Knowledges.id == knowledge_id).first()

    @classmethod
    def add_update_knowledge(cls, session, knowledge_data):
        """
        Update the knowledge matching (id, organisation_id), or create it if absent, then commit.

        ``knowledge_data`` must contain ``id``, ``organisation_id``, ``name``,
        ``description`` and ``index_id``. ``contributed_by`` is also required when a new
        row is created. Returns the updated or created knowledge.
        """
        knowledge = session.query(Knowledges).filter(Knowledges.id == knowledge_data["id"], Knowledges.organisation_id == knowledge_data["organisation_id"]).first()
        if knowledge:
            knowledge.name = knowledge_data["name"]
            knowledge.description = knowledge_data["description"]
            knowledge.vector_db_index_id = knowledge_data["index_id"]
        else:
            knowledge = Knowledges(name=knowledge_data["name"], description=knowledge_data["description"], vector_db_index_id=knowledge_data["index_id"], organisation_id=knowledge_data["organisation_id"], contributed_by=knowledge_data["contributed_by"])
            session.add(knowledge)
        session.commit()
        return knowledge

    @classmethod
    def delete_knowledge(cls, session, knowledge_id):
        """Delete the knowledge with the given ID and commit."""
        session.query(Knowledges).filter(Knowledges.id == knowledge_id).delete()
        session.commit()

    @classmethod
    def delete_knowledge_from_vector_index(cls, session, vector_db_index_id):
        """Delete every knowledge attached to the given vector index and commit."""
        session.query(Knowledges).filter(Knowledges.vector_db_index_id == vector_db_index_id).delete()
        session.commit()
================================================
FILE: superagi/models/marketplace_stats.py
================================================
from __future__ import annotations
from sqlalchemy import Column, Integer, String
import requests
# from superagi.models import AgentConfiguration
from superagi.models.base_model import DBBaseModel
marketplace_url = "https://app.superagi.com/api"
# marketplace_url = "http://localhost:8001"


class MarketPlaceStats(DBBaseModel):
    """
    Represents a marketplace statistic (e.g. a download count) for a referenced item.

    Attributes:
        id (int): The unique identifier of the marketplace stats.
        reference_id (int): The unique identifier of the reference.
        reference_name (str): The kind of reference (e.g. "KNOWLEDGE").
        key (str): The key for the statistical value (e.g. "download_count").
        value (int): The value for the specified key.
    """

    __tablename__ = 'marketplace_stats'

    id = Column(Integer, primary_key=True, autoincrement=True)
    reference_id = Column(Integer)
    reference_name = Column(String)
    key = Column(String)
    value = Column(Integer)

    def __repr__(self):
        """
        Returns a string representation of the MarketplaceStats object.
        """
        return f"MarketPlaceStats(id={self.id}, reference_id='{self.reference_id}', " \
               f"reference_name='{self.reference_name}', key='{self.key}', value='{self.value}')"

    @classmethod
    def get_knowledge_installation_number(cls, knowledge_id: int):
        """
        Fetch a knowledge's download count from the marketplace API.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty list.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + f"/marketplace/knowledge/downloads/{str(knowledge_id)}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            return []

    @classmethod
    def update_knowledge_install_number(cls, session, knowledge_id, install_number):
        """
        Create or update the "download_count" stat of a knowledge and commit.

        Args:
            session: The database session.
            knowledge_id (int): ID of the knowledge.
            install_number: The new count; it is stored as ``str(install_number)``.
        """
        knowledge_install_number = session.query(MarketPlaceStats).filter(MarketPlaceStats.reference_id == knowledge_id, MarketPlaceStats.reference_name == "KNOWLEDGE", MarketPlaceStats.key == "download_count").first()
        if knowledge_install_number is None:
            knowledge_install_number = MarketPlaceStats(reference_id=knowledge_id, reference_name="KNOWLEDGE", key="download_count", value=str(install_number))
            session.add(knowledge_install_number)
        else:
            knowledge_install_number.value = str(install_number)
        session.commit()
================================================
FILE: superagi/models/models.py
================================================
import yaml
from sqlalchemy import Column, Integer, String, and_
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
from superagi.controllers.types.models_types import ModelsTypes
from superagi.helper.encyption_helper import decrypt_data
import requests, logging
from superagi.lib.logger import logger
marketplace_url = "https://app.superagi.com/api"
# marketplace_url = "http://localhost:8001"


class Models(DBBaseModel):
    """
    Represents a Model record in the database.

    Attributes:
        id (Integer): The unique identifier of the model.
        model_name (String): The name of the model.
        description (String): The description for the model.
        end_point (String): The end_point for the model.
        model_provider_id (Integer): The unique id of the model_provider from the models_config table.
        token_limit (Integer): The maximum number of tokens for a model.
        type (String): The place it is added from.
        version (String): The version of the replicate model.
        org_id (Integer): The ID of the organisation.
        model_features (String): The Features of the Model.
        context_length (Integer): Context length of the model; store_model_details writes 0 when None is given.
    """

    __tablename__ = 'models'

    id = Column(Integer, primary_key=True)
    model_name = Column(String, nullable=False)
    description = Column(String, nullable=True)
    end_point = Column(String, nullable=False)
    model_provider_id = Column(Integer, nullable=False)
    token_limit = Column(Integer, nullable=False)
    type = Column(String, nullable=False)
    version = Column(String, nullable=False)
    org_id = Column(Integer, nullable=False)
    model_features = Column(String, nullable=False)
    context_length = Column(Integer, nullable=True)

    def __repr__(self):
        """
        Returns a string representation of the Models instance.
        """
        return f"Models(id={self.id}, model_name={self.model_name}, " \
               f"end_point={self.end_point}, model_provider_id={self.model_provider_id}, " \
               f"token_limit={self.token_limit}, " \
               f"type={self.type}, " \
               f"version={self.version}, " \
               f"org_id={self.org_id}, " \
               f"model_features={self.model_features})"

    @classmethod
    def fetch_marketplace_list(cls, page):
        """
        Fetch one page of models from the marketplace API.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty list.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + f"/models_controller/marketplace/list/{str(page)}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            return []

    @classmethod
    def get_model_install_details(cls, session, marketplace_models, organisation_id, type=ModelsTypes.CUSTOM.value):
        """
        Annotate marketplace model dicts in place with ``is_installed`` and ``installs``.

        ``installs`` counts rows with the same model_name across all organisations.
        ``is_installed`` is always False when ``type`` is the marketplace type;
        otherwise it is True if this organisation has a model of that name.
        Dicts that raise TypeError during annotation are logged and left as-is.

        Returns:
            The same ``marketplace_models`` list.
        """
        from superagi.models.models_config import ModelsConfig
        installed_models = session.query(Models).filter(Models.org_id == organisation_id).all()
        model_counts_dict = dict(
            session.query(Models.model_name, func.count(Models.org_id)).group_by(Models.model_name).all()
        )
        installed_models_dict = {model.model_name: True for model in installed_models}

        for model in marketplace_models:
            try:
                if type == ModelsTypes.MARKETPLACE.value:
                    model["is_installed"] = False
                else:
                    model["is_installed"] = installed_models_dict.get(model["model_name"], False)
                model["installs"] = model_counts_dict.get(model["model_name"], 0)
            except TypeError as e:
                logging.error("Error Occurred: %s", e)

        return marketplace_models

    @classmethod
    def fetch_model_tokens(cls, session, organisation_id) -> Dict[str, int]:
        """
        Return a model_name -> token_limit mapping for the organisation.

        Returns an ``{"error": ...}`` dict when the organisation has no models or the
        query raises.
        """
        try:
            models = session.query(
                Models.model_name, Models.token_limit
            ).filter(
                Models.org_id == organisation_id
            ).all()

            if models:
                return dict(models)
            else:
                return {"error": "No models found for the given organisation ID."}

        except Exception as e:
            logging.error(f"Unexpected Error Occured: {e}")
            return {"error": "Unexpected Error Occured"}

    @classmethod
    def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version, context_length):
        """
        Validate and insert a new model for the organisation.

        Returns:
            dict: ``{"success": ..., "model_id": id}`` on success, otherwise
            ``{"error": ...}`` describing the validation or database failure.
        """
        from superagi.models.models_config import ModelsConfig
        if not model_name:
            return {"error": "Model Name is empty or undefined"}
        if not description:
            return {"error": "Description is empty or undefined"}
        if not model_provider_id:
            return {"error": "Model Provider Id is null or undefined or 0"}
        if not token_limit:
            return {"error": "Token Limit is null or undefined or 0"}

        # Check if model_name already exists in the database
        existing_model = session.query(Models).filter(Models.model_name == model_name, Models.org_id == organisation_id).first()
        if existing_model:
            return {"error": "Model Name already exists"}

        # Get the provider of the model
        if type == 'Marketplace':
            model = ModelsConfig.fetch_model_by_id_marketplace(session, model_provider_id)
        else:
            model = ModelsConfig.fetch_model_by_id(session, organisation_id, model_provider_id)

        if "error" in model:
            return model  # Return error message if model not found

        # Providers in this list may be stored without an explicit end point.
        if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate', 'Local LLM']:
            return {"error": "End Point is empty or undefined"}

        if context_length is None:
            context_length = 0

        try:
            model = Models(
                model_name=model_name,
                description=description,
                end_point=end_point,
                token_limit=token_limit,
                model_provider_id=model_provider_id,
                type=type,
                version=version,
                org_id=organisation_id,
                model_features='',
                context_length=context_length
            )
            session.add(model)
            session.commit()
            session.flush()

        except Exception as e:
            # A failed flush/commit leaves the session unusable until it is rolled back.
            session.rollback()
            logging.error(f"Unexpected Error Occured: {e}")
            return {"error": "Unexpected Error Occured"}

        return {"success": "Model Details stored successfully", "model_id": model.id}

    @classmethod
    def api_key_from_configurations(cls, session, organisation_id):
        """
        Migrate a legacy ``model_api_key`` Configuration entry into ModelsConfig.

        If the organisation has no "OpenAI" ModelsConfig row, the decrypted
        ``model_api_key`` from the Configuration table is stored via
        ``ModelsConfig.store_api_key``. Returns ``{"error": "API Key is Missing"}`` when
        neither exists; otherwise returns None. Exceptions are logged, not raised.
        """
        try:
            from superagi.models.models_config import ModelsConfig
            from superagi.models.configuration import Configuration

            model_provider = session.query(ModelsConfig).filter(ModelsConfig.provider == "OpenAI",
                                                                ModelsConfig.org_id == organisation_id).first()
            if model_provider is None:
                configurations = session.query(Configuration).filter(Configuration.key == 'model_api_key',
                                                                     Configuration.organisation_id == organisation_id).first()

                if configurations is None:
                    return {"error": "API Key is Missing"}
                else:
                    model_api_key = decrypt_data(configurations.value)
                    ModelsConfig.store_api_key(session, organisation_id, "OpenAI", model_api_key)

        except Exception as e:
            logging.error(f"Exception has been raised while checking API Key:: {e}")

    @classmethod
    def fetch_models(cls, session, organisation_id) -> Union[Dict[str, str], List[Dict[str, Union[str, int]]]]:
        """
        List the organisation's models together with their provider names.

        Returns:
            A list of ``{"id", "name", "description", "model_provider"}`` dicts, or an
            ``{"error": ...}`` dict if the query raises.
        """
        try:
            from superagi.models.models_config import ModelsConfig
            cls.api_key_from_configurations(session, organisation_id)

            models = session.query(Models.id, Models.model_name, Models.description, ModelsConfig.provider).join(
                ModelsConfig, Models.model_provider_id == ModelsConfig.id).filter(
                Models.org_id == organisation_id).all()

            result = [
                {
                    "id": model[0],
                    "name": model[1],
                    "description": model[2],
                    "model_provider": model[3]
                }
                for model in models
            ]

        except Exception as e:
            logging.error(f"Unexpected Error Occurred: {e}")
            return {"error": "Unexpected Error Occurred"}

        return result

    @classmethod
    def fetch_model_details(cls, session, organisation_id, model_id: int) -> Dict[str, Union[str, int]]:
        """
        Return the details of one model of the organisation, joined with its provider.

        Returns an ``{"error": ...}`` dict when the model does not exist or the query raises.
        """
        try:
            from superagi.models.models_config import ModelsConfig
            model = session.query(
                Models.id, Models.model_name, Models.description, Models.end_point, Models.token_limit, Models.type,
                ModelsConfig.provider,
            ).join(
                ModelsConfig, Models.model_provider_id == ModelsConfig.id
            ).filter(
                and_(Models.org_id == organisation_id, Models.id == model_id)
            ).first()

            if model:
                return {
                    "id": model[0],
                    "name": model[1],
                    "description": model[2],
                    "end_point": model[3],
                    "token_limit": model[4],
                    "type": model[5],
                    "model_provider": model[6]
                }
            else:
                return {"error": "Model with the given ID doesn't exist."}

        except Exception as e:
            logging.error(f"Unexpected Error Occured: {e}")
            return {"error": "Unexpected Error Occured"}
================================================
FILE: superagi/models/models_config.py
================================================
from sqlalchemy import Column, Integer, String, and_, distinct
from superagi.lib.logger import logger
from superagi.models.base_model import DBBaseModel
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.models import Models
from superagi.llms.openai import OpenAi
from superagi.helper.encyption_helper import encrypt_data, decrypt_data
from fastapi import HTTPException
import logging
class ModelsConfig(DBBaseModel):
    """
    Represents a Model Config record in the database.

    Attributes:
        id (Integer): The unique identifier of the model config.
        provider (String): The name of the model provider.
        api_key (String): The api_key for individual model providers for every Organisation
            (stored encrypted, except for 'Local LLM' rows).
        org_id (Integer): The ID of the organisation.
    """

    __tablename__ = 'models_config'

    id = Column(Integer, primary_key=True)
    provider = Column(String, nullable=False)
    api_key = Column(String, nullable=False)
    org_id = Column(Integer, nullable=False)

    def __repr__(self):
        """
        Returns a string representation of the ModelsConfig instance.
        """
        return f"ModelsConfig(id={self.id}, provider={self.provider}, " \
               f"org_id={self.org_id})"

    @classmethod
    def fetch_value_by_agent_id(cls, session, agent_id: int, model: str):
        """
        Fetch the provider and API key configured for a model, starting from an agent.

        Args:
            session: The database session object.
            agent_id (int): The ID of the agent.
            model (str): The model name to resolve.

        Returns:
            dict: ``{"provider": ..., "api_key": ...}`` with the key decrypted (except for
            'Local LLM', whose key is returned as stored), or None when no config row exists.

        Raises:
            HTTPException: 404 if the agent, project, organisation or model is missing.
        """
        from superagi.models.agent import Agent

        agent = session.query(Agent).filter(Agent.id == agent_id).first()
        if not agent:
            raise HTTPException(status_code=404, detail="Agent not found")

        project = session.query(Project).filter(Project.id == agent.project_id).first()
        if not project:
            raise HTTPException(status_code=404, detail="Project not found")

        organisation = session.query(Organisation).filter(Organisation.id == project.organisation_id).first()
        if not organisation:
            raise HTTPException(status_code=404, detail="Organisation not found")

        model_provider = session.query(Models).filter(Models.org_id == organisation.id, Models.model_name == model).first()
        if not model_provider:
            raise HTTPException(status_code=404, detail="Model provider not found")

        config = session.query(ModelsConfig.provider, ModelsConfig.api_key).filter(ModelsConfig.org_id == organisation.id, ModelsConfig.id == model_provider.model_provider_id).first()
        if not config:
            return None

        # 'Local LLM' keys are stored unencrypted (see add_llm_config, which stores "EMPTY").
        if config.provider == 'Local LLM':
            return {"provider": config.provider, "api_key": config.api_key}
        return {"provider": config.provider, "api_key": decrypt_data(config.api_key)}

    @classmethod
    def store_api_key(cls, session, organisation_id, model_provider, model_api_key):
        """
        Encrypt and store (or replace) the organisation's API key for a provider.

        For 'OpenAI', the default GPT models are also registered via storeGptModels.

        Returns:
            dict: A status message; newly created entries also include ``model_provider_id``.
        """
        existing_entry = session.query(ModelsConfig).filter(and_(ModelsConfig.org_id == organisation_id,
                                                                 ModelsConfig.provider == model_provider)).first()
        if existing_entry:
            existing_entry.api_key = encrypt_data(model_api_key)
            session.commit()
            session.flush()
            if model_provider == 'OpenAI':
                cls.storeGptModels(session, organisation_id, existing_entry.id, model_api_key)
            result = {'message': 'The API key was successfully updated'}
        else:
            new_entry = ModelsConfig(org_id=organisation_id, provider=model_provider,
                                     api_key=encrypt_data(model_api_key))
            session.add(new_entry)
            session.commit()
            session.flush()
            if model_provider == 'OpenAI':
                cls.storeGptModels(session, organisation_id, new_entry.id, model_api_key)
            result = {'message': 'The API key was successfully stored', 'model_provider_id': new_entry.id}

        return result

    @classmethod
    def storeGptModels(cls, session, organisation_id, model_provider_id, model_api_key):
        """
        Register the default OpenAI models visible to ``model_api_key`` that the
        organisation does not have yet.
        """
        # NOTE(review): these limits are below the nominal 4096/8192/16384 windows;
        # presumably a deliberate safety margin — confirm before changing.
        default_models = {"gpt-3.5-turbo": 4032, "gpt-4": 8092, "gpt-3.5-turbo-16k": 16184}
        models = OpenAi(api_key=model_api_key).get_models()
        installed_models = {row[0] for row in session.query(Models.model_name).filter(Models.org_id == organisation_id).all()}
        for model in models:
            if model not in installed_models and model in default_models:
                Models.store_model_details(session, organisation_id, model, model, '',
                                           model_provider_id, default_models[model], 'Custom', '', 0)

    @classmethod
    def fetch_api_keys(cls, session, organisation_id):
        """
        Return the organisation's decrypted API keys, one dict per provider.

        Rows whose stored key is 'EMPTY' (the unencrypted Local LLM placeholder) are skipped.
        """
        api_key_info = session.query(ModelsConfig.provider, ModelsConfig.api_key).filter(
            ModelsConfig.org_id == organisation_id).all()

        if not api_key_info:
            logging.error("No API key found for the provided model provider")
            return []

        api_keys = [{"provider": provider, "api_key": decrypt_data(api_key)} for provider, api_key in api_key_info if api_key != 'EMPTY']
        return api_keys

    @classmethod
    def fetch_api_key(cls, session, organisation_id, model_provider):
        """
        Return a one-element list with the organisation's key entry for a provider.

        The key is decrypted except for 'Local LLM'. Returns [] when no entry exists.
        """
        api_key_data = session.query(ModelsConfig.id, ModelsConfig.provider, ModelsConfig.api_key).filter(
            and_(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == model_provider)).first()

        if api_key_data is None:
            logger.info(f"No API key entry found for provider {model_provider}")
            return []

        # Log identifiers only; the key column is credential material.
        logger.info(f"Fetched API key entry id={api_key_data.id} provider={api_key_data.provider}")
        if api_key_data.provider == 'Local LLM':
            return [{'id': api_key_data.id, 'provider': api_key_data.provider,
                     'api_key': api_key_data.api_key}]
        return [{'id': api_key_data.id, 'provider': api_key_data.provider,
                 'api_key': decrypt_data(api_key_data.api_key)}]

    @classmethod
    def fetch_model_by_id(cls, session, organisation_id, model_provider_id):
        """Return ``{"provider": ...}`` for the org's config row, or ``{"error": "Model not found"}``."""
        model = session.query(ModelsConfig.provider).filter(ModelsConfig.id == model_provider_id,
                                                            ModelsConfig.org_id == organisation_id).first()
        if model is None:
            return {"error": "Model not found"}
        else:
            return {"provider": model.provider}

    @classmethod
    def fetch_model_by_id_marketplace(cls, session, model_provider_id):
        """Like fetch_model_by_id but without restricting to an organisation."""
        model = session.query(ModelsConfig.provider).filter(ModelsConfig.id == model_provider_id).first()
        if model is None:
            return {"error": "Model not found"}
        else:
            return {"provider": model.provider}

    @classmethod
    def add_llm_config(cls, session, organisation_id):
        """Ensure the organisation has a 'Local LLM' config row (placeholder key "EMPTY")."""
        existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == 'Local LLM').first()
        if existing_models_config is None:
            models_config = ModelsConfig(org_id=organisation_id, provider='Local LLM', api_key="EMPTY")
            session.add(models_config)
            session.commit()
================================================
FILE: superagi/models/oauth_tokens.py
================================================
from sqlalchemy import Column, Integer, String, Text
from sqlalchemy.orm import Session
from superagi.models.base_model import DBBaseModel
import json
import yaml
class OauthTokens(DBBaseModel):
    """
    Model representing OAuth tokens stored for a toolkit.

    Attributes:
        id (Integer): The primary key of the oauth token.
        user_id (Integer): The ID of the user associated with the Tokens.
        organisation_id (Integer): The ID of the organisation associated with the Tokens.
        toolkit_id (Integer): The ID of the toolkit associated with the Tokens.
        key (String): The Token Key.
        value (Text): The Token value.
    """

    __tablename__ = 'oauth_tokens'

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer)
    organisation_id = Column(Integer)
    toolkit_id = Column(Integer)
    key = Column(String)
    value = Column(Text)

    def __repr__(self):
        """
        Returns a string representation of the OauthTokens object.

        The token value is a credential, so it is masked instead of printed.

        Returns:
            str: String representation of the OauthTokens object.
        """
        masked_value = None if self.value is None else '***'
        return f"Tokens(id={self.id}, user_id={self.user_id}, organisation_id={self.organisation_id}, " \
               f"toolkit_id={self.toolkit_id}, key={self.key}, value={masked_value})"

    @classmethod
    def add_or_update(cls, session: Session, toolkit_id: int, user_id: int, organisation_id: int, key: str, value: str = None):
        """
        Create or update the token row for a (toolkit_id, user_id) pair and commit.

        An existing row only has its value replaced, and only when ``value`` is not None.
        Its key and organisation_id are left unchanged.
        """
        oauth_tokens = session.query(OauthTokens).filter_by(toolkit_id=toolkit_id, user_id=user_id).first()
        if oauth_tokens:
            # Update existing oauth tokens
            if value is not None:
                oauth_tokens.value = value
        else:
            # Create new oauth tokens
            oauth_tokens = OauthTokens(toolkit_id=toolkit_id, user_id=user_id, organisation_id=organisation_id, key=key, value=value)
            session.add(oauth_tokens)
        session.commit()
================================================
FILE: superagi/models/organisation.py
================================================
from sqlalchemy import Column, Integer, String
from superagi.helper.tool_helper import register_toolkits
from superagi.models.base_model import DBBaseModel
class Organisation(DBBaseModel):
    """
    ORM model for an organisation.

    Attributes:
        id (Integer): Primary key of the organisation.
        name (String): Display name of the organisation.
        description (String): Free-text description of the organisation.
    """

    __tablename__ = 'organisations'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    description = Column(String)

    def __repr__(self):
        """Short debug string with the organisation's id and name."""
        return "Organisation(id={}, name='{}')".format(self.id, self.name)

    @classmethod
    def find_or_create_organisation(cls, session, user):
        """
        Resolve the organisation for ``user``, creating a default one if needed.

        Resolution order:
        1. The organisation referenced by ``user.organisation_id`` (may be None if missing).
        2. An existing "Default Organization - <user.id>" organisation, which is then
           linked to the user.
        3. A newly created default organisation, which is linked to the user and has
           toolkits registered for it.

        Args:
            session: The database session.
            user: The user object.

        Returns:
            Organisation: The found or created organisation.
        """
        if user.organisation_id is not None:
            return session.query(Organisation).filter(Organisation.id == user.organisation_id).first()

        default_name = "Default Organization - " + str(user.id)
        found = session.query(Organisation).filter(Organisation.name == default_name).first()
        if found is not None:
            user.organisation_id = found.id
            session.commit()
            return found

        created = Organisation(name=default_name, description="New default organiztaion")
        session.add(created)
        session.commit()
        session.flush()

        user.organisation_id = created.id
        session.commit()

        register_toolkits(session=session, organisation=created)
        return created
================================================
FILE: superagi/models/project.py
================================================
from sqlalchemy import Column, Integer, String,ForeignKey
from superagi.models.base_model import DBBaseModel
class Project(DBBaseModel):
    """
    ORM model for a project owned by an organisation.

    Attributes:
        id (Integer): Primary key of the project.
        name (String): Display name of the project.
        organisation_id (Integer): ID of the owning organisation.
        description (String): Free-text description of the project.
    """

    __tablename__ = 'projects'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    organisation_id = Column(Integer)
    description = Column(String)

    def __repr__(self):
        """Short debug string with the project's id and name."""
        return "Project(id={}, name='{}')".format(self.id, self.name)

    @classmethod
    def find_or_create_default_project(cls, session, organisation_id):
        """
        Return the organisation's "Default Project", creating and committing it if absent.

        Args:
            session: The database session.
            organisation_id (int): The ID of the organisation.

        Returns:
            Project: The existing or newly created default project.
        """
        existing = session.query(Project).filter(
            Project.organisation_id == organisation_id,
            Project.name == "Default Project",
        ).first()
        if existing is not None:
            return existing

        created = Project(
            name="Default Project",
            organisation_id=organisation_id,
            description="New Default Project",
        )
        session.add(created)
        session.commit()
        session.flush()
        return created

    @classmethod
    def find_by_org_id(cls, session, org_id: int):
        """Return the first project of the organisation, or None."""
        return session.query(Project).filter(Project.organisation_id == org_id).first()

    @classmethod
    def find_by_id(cls, session, project_id: int):
        """Return the project with the given ID, or None."""
        return session.query(Project).filter(Project.id == project_id).first()
================================================
FILE: superagi/models/resource.py
================================================
from sqlalchemy import Column, Integer, String, Float, Text
from superagi.models.base_model import DBBaseModel
from sqlalchemy.orm import sessionmaker
class Resource(DBBaseModel):
    """
    Model representing a resource (a file associated with an agent and/or agent run).

    Attributes:
        id (Integer): The primary key of the resource.
        name (String): The name of the resource.
        storage_type (String): The storage type of the resource. validate_resource_type
            accepts "FILE" and "S3".
        path (String): The path of the resource (required for S3 storage type).
        size (Integer): The size of the resource.
        type (String): The type of the resource (e.g., application/pdf).
        channel (String): The channel of the resource (INPUT, OUTPUT).
        agent_id (Integer): The ID of the agent associated with the resource.
        agent_execution_id (Integer): The ID of the agent execution corresponding to the resource.
        summary (Text): Free-text summary stored alongside the resource
            (presumably of its contents — confirm with the writer of this column).
    """

    __tablename__ = 'resources'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    storage_type = Column(String)  # FILE, S3 (see validate_resource_type)
    path = Column(String)  # need for S3
    size = Column(Integer)
    type = Column(String)  # application/pdf etc
    channel = Column(String)  # INPUT,OUTPUT
    agent_id = Column(Integer)
    agent_execution_id = Column(Integer)
    summary = Column(Text)

    def __repr__(self):
        """
        Returns a string representation of the Resource object.

        Returns:
            str: String representation of the Resource object.
        """
        # path's quote is now closed; previously it swallowed ", size=" into the path value.
        return (f"Resource(id={self.id}, name='{self.name}', storage_type='{self.storage_type}', "
                f"path='{self.path}', size='{self.size}', type='{self.type}', channel={self.channel}, "
                f"agent_id={self.agent_id}, agent_execution_id={self.agent_execution_id})")

    @staticmethod
    def validate_resource_type(storage_type):
        """
        Validates the resource storage type.

        Args:
            storage_type (str): The storage type to validate.

        Raises:
            InvalidResourceType: If the storage type is not "FILE" or "S3".
        """
        valid_types = ["FILE", "S3"]
        if storage_type not in valid_types:
            raise InvalidResourceType("Invalid resource type")

    @classmethod
    def find_by_run_ids(cls, session, run_ids: list):
        """
        Fetch all resources attached to any of the given agent executions.

        Args:
            session: The database session.
            run_ids (list): Agent execution IDs to match against ``agent_execution_id``.

        Returns:
            list[Resource]: Matching resources.
        """
        db_resources_arr = session.query(Resource).filter(Resource.agent_execution_id.in_(run_ids)).all()
        return db_resources_arr
class InvalidResourceType(Exception):
    """Raised by Resource.validate_resource_type when a storage type is not supported."""

    pass
================================================
FILE: superagi/models/tool.py
================================================
from sqlalchemy import Column, Integer, String
from superagi.models.base_model import DBBaseModel
# from pydantic import BaseModel
class Tool(DBBaseModel):
    """
    Model representing a tool.

    Attributes:
        id (Integer): The primary key of the tool.
        name (String): The name of the tool.
        description (String): The description of the tool.
        folder_name (String): The folder name of the tool.
        class_name (String): The class name of the tool.
        file_name (String): The file name of the tool.
        toolkit_id (Integer): The ID of the toolkit the tool belongs to.
    """

    __tablename__ = 'tools'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)
    description = Column(String)
    folder_name = Column(String)
    class_name = Column(String)
    file_name = Column(String)
    toolkit_id = Column(Integer)

    def __repr__(self):
        """
        Returns a string representation of the Tool object.

        Returns:
            str: String representation of the Tool object.
        """
        # Fields are comma separated and each quoted value is opened and closed
        # (previously description lacked a trailing comma and class_name's quote
        # wrapped toolkit_id).
        return (f"Tool(id={self.id}, name='{self.name}', description='{self.description}', "
                f"folder_name='{self.folder_name}', file_name='{self.file_name}', "
                f"class_name='{self.class_name}', toolkit_id={self.toolkit_id})")

    def to_dict(self):
        """
        Convert the Tool instance to a dictionary.

        Returns:
            dict: A dictionary representation of the Tool instance.
        """
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "folder_name": self.folder_name,
            "class_name": self.class_name,
            "file_name": self.file_name,
            "toolkit_id": self.toolkit_id
        }

    @staticmethod
    def add_or_update(session, tool_name: str, description: str, folder_name: str, class_name: str, file_name: str,
                      toolkit_id: int):
        """
        Create a tool inside a toolkit, or update the tool of the same name in that toolkit.

        Args:
            session: The database session.
            tool_name (str): Name of the tool (unique key together with toolkit_id).
            description (str): Tool description.
            folder_name (str): Folder containing the tool.
            class_name (str): Class implementing the tool.
            file_name (str): File containing the tool class.
            toolkit_id (int): ID of the owning toolkit.

        Returns:
            Tool: The created or updated tool, after commit.
        """
        # Check if a record with the given tool name already exists inside a toolkit
        tool = session.query(Tool).filter_by(name=tool_name,
                                             toolkit_id=toolkit_id).first()
        if tool is not None:
            # Update the attributes of the existing tool record
            tool.folder_name = folder_name
            tool.class_name = class_name
            tool.file_name = file_name
            tool.description = description
        else:
            # Create a new tool record
            tool = Tool(name=tool_name, description=description, folder_name=folder_name, class_name=class_name,
                        file_name=file_name,
                        toolkit_id=toolkit_id)
            session.add(tool)
        session.commit()
        session.flush()
        return tool

    @staticmethod
    def delete_tool(session, tool_name):
        """
        Delete the first tool found with the given name and commit.

        The lookup matches on name only (not toolkit), so with duplicate names
        across toolkits only one of them is removed. Does nothing if no tool matches.
        """
        tool = session.query(Tool).filter(Tool.name == tool_name).first()
        if tool:
            session.delete(tool)
            session.commit()
            session.flush()

    @classmethod
    def convert_tool_names_to_ids(cls, db, tool_names):
        """
        Converts a list of tool names to their corresponding IDs.

        Args:
            db: Object exposing the database session as ``db.session``.
            tool_names (list): List of tool names.

        Returns:
            list: List of tool IDs (names that match nothing are omitted).
        """
        tools = db.session.query(Tool).filter(Tool.name.in_(tool_names)).all()
        return [tool.id for tool in tools]

    @classmethod
    def convert_tool_ids_to_names(cls, db, tool_ids):
        """
        Converts a list of tool IDs to their corresponding names.

        Args:
            db: Object exposing the database session as ``db.session``.
            tool_ids (list): List of tool IDs.

        Returns:
            list: List of tool names as strings (IDs that match nothing are omitted).
        """
        tools = db.session.query(Tool).filter(Tool.id.in_(tool_ids)).all()
        return [str(tool.name) for tool in tools]

    @classmethod
    def get_invalid_tools(cls, tool_ids, session):
        """
        Return the subset of ``tool_ids`` (in input order) that do not exist in the table.

        NOTE(review): ``Query.get`` is a legacy API in SQLAlchemy 2.x; ``session.get(Tool, id)``
        is the modern equivalent.
        """
        invalid_tool_ids = []
        for tool_id in tool_ids:
            tool = session.query(Tool).get(tool_id)
            if tool is None:
                invalid_tool_ids.append(tool_id)
        return invalid_tool_ids

    @classmethod
    def get_toolkit_tools(cls, session, toolkit_id: int):
        """Return all tools belonging to ``toolkit_id``."""
        return session.query(Tool).filter(Tool.toolkit_id == toolkit_id).all()
================================================
FILE: superagi/models/tool_config.py
================================================
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy.orm import Session, sessionmaker
from superagi.types.key_type import ToolConfigKeyType
from superagi.models.base_model import DBBaseModel
from superagi.helper.encyption_helper import encrypt_data
import json
import yaml
class ToolConfig(DBBaseModel):
    """
    Model representing a tool configuration.

    Attributes:
        id (Integer): The primary key of the tool configuration.
        key (String): The key of the tool configuration.
        value (String): The value of the tool configuration.
        toolkit_id (Integer): The identifier of the associated toolkit.
        key_type (String): the type of key used.
        is_secret (Boolean): Whether the tool configuration is a secret.
        is_required (Boolean): Whether the tool configuration is a required field.
    """

    __tablename__ = 'tool_configs'

    id = Column(Integer, primary_key=True)
    key = Column(String)
    value = Column(String)
    toolkit_id = Column(Integer)
    key_type = Column(String)
    is_secret = Column(Boolean)
    is_required = Column(Boolean)

    def __repr__(self):
        # Quote now closes right after value instead of after toolkit_id.
        # NOTE(review): value is printed verbatim even when is_secret is True.
        return f"ToolConfig(id={self.id}, key='{self.key}', value='{self.value}', toolkit_id={self.toolkit_id})"

    def to_dict(self):
        """
        Convert the tool configuration to a JSON-serialisable dictionary.

        Returns:
            dict: Column values keyed by column name.
        """
        return {
            'id': self.id,
            'key': self.key,
            'value': self.value,
            # Plain value: the former ``{self.toolkit_id}`` built a one-element set,
            # which json.dumps rejects, so to_json always raised TypeError.
            'toolkit_id': self.toolkit_id,
            'key_type': self.key_type,
            'is_secret': self.is_secret,
            'is_required': self.is_required
        }

    def to_json(self):
        """Serialise the tool configuration (via to_dict) to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """
        Build a ToolConfig from a JSON string produced by to_json.

        Raises:
            KeyError: If any expected field is missing from the JSON object.
        """
        data = json.loads(json_data)
        return cls(
            id=data['id'],
            key=data['key'],
            value=data['value'],
            toolkit_id=data['toolkit_id'],
            key_type=data['key_type'],
            is_secret=data['is_secret'],
            is_required=data['is_required']
        )

    @staticmethod
    def add_or_update(session: Session, toolkit_id: int, key: str, value: str = None, key_type: str = None,
                      is_secret: bool = False, is_required: bool = False):
        """
        Create the (toolkit_id, key) config, or update it if it already exists, then commit.

        On update, ``value`` is only overwritten when not None; ``is_required``/``is_secret``
        of None are stored as False; ``key_type`` of None becomes ToolConfigKeyType.STRING,
        and a ToolConfigKeyType member is stored as its ``.value``.

        Raises:
            ValueError: On update, if is_required or is_secret is neither None nor a bool.
        """
        tool_config = session.query(ToolConfig).filter_by(toolkit_id=toolkit_id, key=key).first()
        if tool_config:
            # Update existing tool config
            if value is not None:
                tool_config.value = (value)

            if is_required is None:
                tool_config.is_required = False
            elif isinstance(is_required, bool):
                tool_config.is_required = is_required
            else:
                raise ValueError("is_required should be a boolean value")

            if is_secret is None:
                tool_config.is_secret = False
            elif isinstance(is_secret, bool):
                tool_config.is_secret = is_secret
            else:
                raise ValueError("is_secret should be a boolean value")

            if key_type is None:
                tool_config.key_type = ToolConfigKeyType.STRING.value
            elif isinstance(key_type, ToolConfigKeyType):
                tool_config.key_type = key_type.value
            else:
                tool_config.key_type = key_type

        else:
            # Create new tool config
            if key_type is None:
                key_type = ToolConfigKeyType.STRING.value
            if isinstance(key_type, ToolConfigKeyType):
                key_type = key_type.value
            tool_config = ToolConfig(toolkit_id=toolkit_id, key=key, value=value, key_type=key_type,
                                     is_secret=is_secret, is_required=is_required)
            session.add(tool_config)

        session.commit()

    @classmethod
    def get_toolkit_tool_config(cls, session: Session, toolkit_id: int):
        """Return every configuration row of the given toolkit."""
        return session.query(ToolConfig).filter_by(toolkit_id=toolkit_id).all()
================================================
FILE: superagi/models/toolkit.py
================================================
import json
import requests
from sqlalchemy import Column, Integer, String, Boolean
from superagi.models.base_model import DBBaseModel
from superagi.models.tool import Tool
# Base URL of the SuperAGI marketplace API; Toolkit.fetch_marketplace_list and
# Toolkit.fetch_marketplace_detail append their paths to it.
marketplace_url = "https://app.superagi.com/api"
# Local-development alternative, kept commented out:
# marketplace_url = "http://localhost:8001"
class Toolkit(DBBaseModel):
    """
    ToolKit - Used to group tools together

    Attributes:
        id (int): id of the toolkit
        name (str): name of the toolkit
        description (str): description of the toolkit
        show_toolkit (bool): indicates whether the toolkit should be shown based on the count of tools in the toolkit
        organisation_id (int): id of the organisation the toolkit belongs to
        tool_code_link (str): stores the GitHub link for the toolkit
    """

    __tablename__ = 'toolkits'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    description = Column(String)
    show_toolkit = Column(Boolean)
    organisation_id = Column(Integer)
    tool_code_link = Column(String)

    def __repr__(self):
        # Previously the closing ")" was missing and tool_code_link was omitted.
        return (f"ToolKit(id={self.id}, name='{self.name}', description='{self.description}', "
                f"show_toolkit={self.show_toolkit}, organisation_id={self.organisation_id}, "
                f"tool_code_link='{self.tool_code_link}')")

    def to_dict(self):
        """Return the toolkit's core columns as a JSON-serialisable dict."""
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'show_toolkit': self.show_toolkit,
            'organisation_id': self.organisation_id
        }

    def to_json(self):
        """Serialise the toolkit (via to_dict) to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """Build a Toolkit from a JSON string produced by to_json."""
        data = json.loads(json_data)
        return cls(
            id=data['id'],
            name=data['name'],
            description=data['description'],
            show_toolkit=data['show_toolkit'],
            organisation_id=data['organisation_id']
        )

    @staticmethod
    def add_or_update(session, name, description, show_toolkit, organisation_id, tool_code_link):
        """
        Create the organisation's toolkit called ``name`` or update it if it exists, then commit.

        Returns:
            Toolkit: The created or updated toolkit.
        """
        # Check if the toolkit exists
        toolkit = session.query(Toolkit).filter(Toolkit.name == name,
                                                Toolkit.organisation_id == organisation_id).first()
        if toolkit:
            # Update the existing toolkit
            toolkit.name = name
            toolkit.description = description
            toolkit.show_toolkit = show_toolkit
            toolkit.organisation_id = organisation_id
            toolkit.tool_code_link = tool_code_link
        else:
            # Create a new toolkit
            toolkit = Toolkit(
                name=name,
                description=description,
                show_toolkit=show_toolkit,
                organisation_id=organisation_id,
                tool_code_link=tool_code_link
            )
            session.add(toolkit)
        session.commit()
        session.flush()
        return toolkit

    @classmethod
    def fetch_marketplace_list(cls, page):
        """
        Fetch one page of marketplace toolkits.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty list.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + f"/toolkits/marketplace/list/{str(page)}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            return []

    @classmethod
    def fetch_marketplace_detail(cls, search_str, toolkit_name):
        """
        Fetch a marketplace toolkit's details.

        Args:
            search_str (str): First path segment after ``/toolkits/marketplace/``.
            toolkit_name (str): Name of the toolkit.

        Returns:
            The decoded JSON body on HTTP 200, otherwise None.
        """
        from urllib.parse import quote
        headers = {'Content-Type': 'application/json'}
        # Percent-encode every reserved character, not only spaces, so values with
        # '/', '?', '#', '&' stay a single path segment. quote(' ') == '%20', so
        # names that worked before produce the same URL.
        search_str = quote(search_str, safe='')
        toolkit_name = quote(toolkit_name, safe='')
        response = requests.get(
            marketplace_url + f"/toolkits/marketplace/{search_str}/{toolkit_name}",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            return None

    @staticmethod
    def get_toolkit_from_name(session, toolkit_name, organisation):
        """Return the organisation's toolkit called ``toolkit_name``, or None if it has none."""
        return session.query(Toolkit).filter_by(name=toolkit_name, organisation_id=organisation.id).first()

    @classmethod
    def get_toolkit_installed_details(cls, session, marketplace_toolkits, organisation):
        """
        Flag each marketplace toolkit dict with ``is_installed`` for the organisation.

        Mutates and returns ``marketplace_toolkits``; an entry is installed when the
        organisation has a toolkit with the same ``name``.
        """
        installed_toolkits = session.query(Toolkit).filter(Toolkit.organisation_id == organisation.id).all()
        # Build the name set once instead of rebuilding a list per marketplace entry.
        installed_names = {installed_toolkit.name for installed_toolkit in installed_toolkits}
        for toolkit in marketplace_toolkits:
            toolkit["is_installed"] = toolkit['name'] in installed_names
        return marketplace_toolkits

    @classmethod
    def fetch_tool_ids_from_toolkit(cls, session, toolkit_ids):
        """
        Return the IDs of all tools in the given toolkits, grouped in toolkit order.
        """
        agent_toolkit_tools = []
        for toolkit_id in toolkit_ids:
            toolkit_tools = session.query(Tool).filter(Tool.toolkit_id == toolkit_id).all()
            # The rows were just loaded, so re-querying each one by id is unnecessary.
            agent_toolkit_tools.extend(tool.id for tool in toolkit_tools)
        return agent_toolkit_tools

    @classmethod
    def get_tool_and_toolkit_arr(cls, session, organisation_id: int, agent_config_tools_arr: list):
        """
        Resolve ``[{"name": <toolkit>, "tools": [<tool>, ...]}, ...]`` to a list of tool IDs.

        When an entry has no (or an empty) ``"tools"`` list, every tool of that toolkit is used.
        Names are stripped of surrounding whitespace before lookup. The result is de-duplicated
        and its order is unspecified.

        Raises:
            Exception: If a toolkit (for the organisation) or a named tool inside it does not exist.
        """
        tools_arr = set()
        for tool_obj in agent_config_tools_arr:
            toolkit = session.query(Toolkit).filter(Toolkit.name == tool_obj["name"].strip(),
                                                    Toolkit.organisation_id == organisation_id).first()
            if toolkit is None:
                raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.")

            if tool_obj.get("tools"):
                for tool_name_str in tool_obj["tools"]:
                    tool_db_obj = session.query(Tool).filter(Tool.name == tool_name_str.strip(),
                                                             Tool.toolkit_id == toolkit.id).first()
                    if tool_db_obj is None:
                        raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.")
                    tools_arr.add(tool_db_obj.id)
            else:
                for tool_db_obj in Tool.get_toolkit_tools(session, toolkit.id):
                    tools_arr.add(tool_db_obj.id)
        return list(tools_arr)
================================================
FILE: superagi/models/types/__init__.py
================================================
================================================
FILE: superagi/models/types/agent_config.py
================================================
from typing import Union
from pydantic import BaseModel
class AgentConfig(BaseModel):
    """
    Pydantic schema for one agent configuration entry.

    Attributes:
        agent_id (int): ID of the agent the entry belongs to.
        key (str): Configuration key.
        value (Union[str, list]): Configuration value, either a string or a list.
    """
    agent_id: int
    key: str
    value: Union[str, list]

    def __repr__(self):
        # This schema has no ``id`` field; reading self.id raised AttributeError on every repr().
        return f"AgentConfiguration(agent_id={self.agent_id}, key={self.key}, value={self.value})"
================================================
FILE: superagi/models/types/login_request.py
================================================
from pydantic import BaseModel
class LoginRequest(BaseModel):
    """
    Pydantic schema carrying login credentials.

    NOTE(review): presumably the request body of the email/password login
    endpoint — confirm against the route that uses it.

    Attributes:
        email (str): Email address identifying the user.
        password (str): Password supplied by the client.
    """
    email: str
    password: str
================================================
FILE: superagi/models/types/validate_llm_api_key_request.py
================================================
from pydantic import BaseModel
class ValidateAPIKeyRequest(BaseModel):
    """
    Pydantic schema for a request to validate an LLM API key.

    Attributes:
        model_source (str): Identifies the model provider the key belongs to
            (accepted values are not visible here — confirm with the validator).
        model_api_key (str): The API key to validate.
    """
    model_source: str
    model_api_key: str
================================================
FILE: superagi/models/user.py
================================================
from sqlalchemy import Column, Integer, String
from superagi.models.base_model import DBBaseModel
# from pydantic import BaseModel
class User(DBBaseModel):
    """
    Model representing a user.

    Attributes:
        id (Integer): The primary key of the user.
        name (String): The name of the user.
        email (String): The email of the user (unique).
        password (String): The stored password value of the user (whether it is
            hashed is not visible in this model — confirm with the writer).
        organisation_id (Integer): The ID of the associated organisation.
        first_login_source (String): Source of the user's first login
            (allowed values are not visible here — confirm).
    """

    __tablename__ = 'users'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)
    email = Column(String, unique=True)
    password = Column(String)
    organisation_id = Column(Integer)
    first_login_source = Column(String)

    def __repr__(self):
        """
        Returns a string representation of the User object.

        The password column is deliberately omitted so that logging, printing, or
        a traceback containing a User never exposes it.

        Returns:
            str: String representation of the User object.
        """
        return (f"User(id={self.id}, name='{self.name}', email='{self.email}', "
                f"organisation_id={self.organisation_id}, first_login_source={self.first_login_source})")
================================================
FILE: superagi/models/vector_db_configs.py
================================================
from sqlalchemy import Column, Integer, Text, String
from superagi.models.base_model import DBBaseModel
class VectordbConfigs(DBBaseModel):
    """
    Key/value configuration rows for a vector db (e.g. api_key, environment, url).

    Every setting is a separate row; all rows of one vector db share ``vector_db_id``.

    Attributes:
        id (int): The unique identifier of the configuration row.
        vector_db_id (int): The identifier of the associated vector db.
        key (str): The name of the configuration setting.
        value (str): The value of the configuration setting.
    """

    __tablename__ = 'vector_db_configs'

    id = Column(Integer, primary_key=True, autoincrement=True)
    vector_db_id = Column(Integer)
    key = Column(String)
    value = Column(Text)

    def __repr__(self):
        """
        Describe this vector db configuration row.

        NOTE(review): the raw value is included, which may be a credential.

        Returns:
            str: ``VectorConfiguration(id=..., key=..., value=...)``.
        """
        return "VectorConfiguration(id={}, key={}, value={})".format(self.id, self.key, self.value)

    @classmethod
    def get_vector_db_config_from_db_id(cls, session, vector_db_id):
        """Return every setting of one vector db as ``{key: value}`` (a later row wins on a repeated key)."""
        rows = session.query(VectordbConfigs).filter(VectordbConfigs.vector_db_id == vector_db_id).all()
        return {row.key: row.value for row in rows}

    @classmethod
    def add_vector_db_config(cls, session, vector_db_id, db_creds):
        """Add one configuration row per ``db_creds`` item for ``vector_db_id``, then commit."""
        for setting_key, setting_value in db_creds.items():
            session.add(VectordbConfigs(vector_db_id=vector_db_id, key=setting_key, value=setting_value))
        session.commit()

    @classmethod
    def delete_vector_db_configs(cls, session, vector_db_id):
        """Delete every configuration row of ``vector_db_id`` and commit."""
        matching_rows = session.query(VectordbConfigs).filter(VectordbConfigs.vector_db_id == vector_db_id)
        matching_rows.delete()
        session.commit()
================================================
FILE: superagi/models/vector_db_indices.py
================================================
from __future__ import annotations
from sqlalchemy import Column, Integer, String
# from superagi.models import AgentConfiguration
from superagi.models.base_model import DBBaseModel
class VectordbIndices(DBBaseModel):
    """
    An index/collection inside a registered vector db.

    Attributes:
        id (int): The unique identifier of the index row.
        name (str): The name of the index/collection (referred to as a class in Weaviate).
        vector_db_id (int): The identifier of the associated vector db.
        dimensions (int): Vector dimensionality; add_vector_index leaves it None for Weaviate.
        state (str): State string of the index (values are set by callers, not visible here).
    """

    __tablename__ = 'vector_db_indices'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)
    vector_db_id = Column(Integer)
    dimensions = Column(Integer)
    state = Column(String)

    def __repr__(self):
        """
        Describe this vector db index.

        Returns:
            str: ``VectordbIndices(id=..., name='...', vector_db_id=..., dimensions=..., state=...)``.
        """
        return ("VectordbIndices(id={}, name='{}', vector_db_id={}, dimensions={}, state={})"
                .format(self.id, self.name, self.vector_db_id, self.dimensions, self.state))

    @classmethod
    def get_vector_index_from_id(cls, session, vector_db_index_id):
        """Return the index with primary key ``vector_db_index_id``, or None."""
        return session.query(VectordbIndices).filter(VectordbIndices.id == vector_db_index_id).first()

    @classmethod
    def get_vector_indices_from_vectordb(cls, session, vector_db_id):
        """Return every index that belongs to ``vector_db_id``."""
        return session.query(VectordbIndices).filter(VectordbIndices.vector_db_id == vector_db_id).all()

    @classmethod
    def delete_vector_db_index(cls, session, vector_index_id):
        """Delete the index with primary key ``vector_index_id`` and commit."""
        doomed = session.query(VectordbIndices).filter(VectordbIndices.id == vector_index_id)
        doomed.delete()
        session.commit()

    @classmethod
    def add_vector_index(cls, session, index_name, vector_db_id, state, dimensions=None):  # will be none only in the case of weaviate
        """Insert a new index row for ``vector_db_id`` and commit."""
        session.add(VectordbIndices(name=index_name, vector_db_id=vector_db_id,
                                    dimensions=dimensions, state=state))
        session.commit()

    @classmethod
    def update_vector_index_state(cls, session, index_id, state):
        """Set the ``state`` of index ``index_id`` and commit (raises AttributeError if it does not exist)."""
        target = session.query(VectordbIndices).filter(VectordbIndices.id == index_id).first()
        target.state = state
        session.commit()
================================================
FILE: superagi/models/vector_dbs.py
================================================
from __future__ import annotations
import requests
from sqlalchemy import Column, Integer, String
# from superagi.models import AgentConfiguration
from superagi.models.base_model import DBBaseModel
# Base URL of the SuperAGI marketplace API; Vectordbs.fetch_marketplace_list
# appends its path to it.
marketplace_url = "https://app.superagi.com/api"
# Local-development alternative, kept commented out:
# marketplace_url = "http://localhost:8001"
class Vectordbs(DBBaseModel):
    """
    Represents a vector db registered for an organisation.

    Attributes:
        id (int): The unique identifier of the vector db.
        name (str): The name of the vector db.
        db_type (str): The kind of vector database behind this entry
            (exact allowed values are not visible here — confirm).
        organisation_id (int): The identifier of the associated organisation.
    """

    __tablename__ = 'vector_dbs'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)
    db_type = Column(String)
    organisation_id = Column(Integer)

    def __repr__(self):
        """
        Returns a string representation of the Vector db object.

        NOTE(review): ``updated_at`` is not declared here; presumably it comes from DBBaseModel — confirm.

        Returns:
            str: String representation of the Vector db.
        """
        # A comma was missing between db_type and organisation_id.
        return (f"Vector(id={self.id}, name='{self.name}', db_type='{self.db_type}', "
                f"organisation_id={self.organisation_id}, updated_at={self.updated_at})")

    @classmethod
    def get_vector_db_from_id(cls, session, vector_db_id):
        """Return the vector db with primary key ``vector_db_id``, or None."""
        vector_db = session.query(Vectordbs).filter(Vectordbs.id == vector_db_id).first()
        return vector_db

    @classmethod
    def fetch_marketplace_list(cls):
        """
        Fetch the list of vector dbs offered by the marketplace.

        Returns:
            The decoded JSON body on HTTP 200, otherwise an empty list.
        """
        headers = {'Content-Type': 'application/json'}
        response = requests.get(
            marketplace_url + "/vector_dbs/marketplace/list",
            headers=headers, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            return []

    @classmethod
    def get_vector_db_from_organisation(cls, session, organisation):
        """Return every vector db registered for ``organisation``."""
        vector_db_list = session.query(Vectordbs).filter(Vectordbs.organisation_id == organisation.id).all()
        return vector_db_list

    @classmethod
    def add_vector_db(cls, session, name, db_type, organisation):
        """Register a new vector db for ``organisation``, commit, and return it."""
        vector_db = Vectordbs(name=name, db_type=db_type, organisation_id=organisation.id)
        session.add(vector_db)
        session.commit()
        return vector_db

    @classmethod
    def delete_vector_db(cls, session, vector_db_id):
        """Delete the vector db with primary key ``vector_db_id`` and commit."""
        session.query(Vectordbs).filter(Vectordbs.id == vector_db_id).delete()
        session.commit()
================================================
FILE: superagi/models/webhook_events.py
================================================
from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey
from sqlalchemy.orm import relationship
from superagi.models.base_model import DBBaseModel
from superagi.models.agent_execution import AgentExecution
class WebhookEvents(DBBaseModel):
    """
    ORM model for the ``webhook_events`` table.

    Attributes:
        id (int): Primary key.
        agent_id (int): ID of the agent the event relates to.
        run_id (int): ID of the run the event relates to (presumably an agent
            execution id — confirm).
        event (str): Event name (values not visible here).
        status (str): Status of the event (values not visible here; presumably the
            webhook delivery status — confirm).
        errors (str): Error text recorded for the event.
    """
    __tablename__ = 'webhook_events'
    id = Column(Integer, primary_key=True)
    agent_id=Column(Integer)
    run_id = Column(Integer)
    event = Column(String)
    status = Column(String)
    errors= Column(Text)
================================================
FILE: superagi/models/webhooks.py
================================================
from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey,JSON
from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import JSONB
from superagi.models.base_model import DBBaseModel
from superagi.models.agent_execution import AgentExecution
class Webhooks(DBBaseModel):
    """
    ORM model for the ``webhooks`` table.

    Attributes:
        id (int): Primary key.
        name (str): Name of the webhook.
        org_id (int): ID of the owning organisation.
        url (str): Target URL of the webhook.
        headers (JSON): Header data stored as JSON (presumably HTTP headers sent
            with each call — confirm).
        is_deleted (bool): Soft-delete flag.
        filters (JSON): Filter data stored as JSON (schema not visible here).
    """
    __tablename__ = 'webhooks'
    id = Column(Integer, primary_key=True)
    name=Column(String)
    org_id = Column(Integer)
    url = Column(String)
    headers=Column(JSON)
    is_deleted=Column(Boolean)
    filters=Column(JSON)
================================================
FILE: superagi/models/workflows/__init__.py
================================================
================================================
FILE: superagi/models/workflows/agent_workflow.py
================================================
import json
from sqlalchemy import Column, Integer, String, Text
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.base_model import DBBaseModel
class AgentWorkflow(DBBaseModel):
    """
    A named agent workflow. Its steps are AgentWorkflowStep rows whose
    ``agent_workflow_id`` points at this workflow.

    Attributes:
        id (int): The unique identifier of the agent workflow.
        name (str): The name of the agent workflow.
        description (str): The description of the agent workflow.
    """

    __tablename__ = 'agent_workflows'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    description = Column(Text)

    def __repr__(self):
        """
        Describe this workflow.

        Returns:
            str: ``AgentWorkflow(id=..., name='...', description='...')``.
        """
        return "AgentWorkflow(id={}, name='{}', description='{}')".format(self.id, self.name, self.description)

    def to_dict(self):
        """
        Serialise the workflow's columns.

        Returns:
            dict: ``{'id': ..., 'name': ..., 'description': ...}``.
        """
        return dict(id=self.id, name=self.name, description=self.description)

    def to_json(self):
        """
        Serialise the workflow (via to_dict) to a JSON string.

        Returns:
            str: JSON text.
        """
        payload = self.to_dict()
        return json.dumps(payload)

    @classmethod
    def from_json(cls, json_data):
        """
        Rebuild a workflow from JSON produced by to_json.

        Args:
            json_data (str): JSON object with ``id``, ``name`` and ``description`` keys.

        Returns:
            AgentWorkflow: A new, unsaved instance.

        Raises:
            KeyError: If a required key is missing.
        """
        payload = json.loads(json_data)
        return cls(**{field: payload[field] for field in ('id', 'name', 'description')})

    @classmethod
    def fetch_trigger_step_id(cls, session, workflow_id):
        """
        Look up the TRIGGER step of a workflow.

        Despite its name this returns the step object, not its id.

        Args:
            session: The database session.
            workflow_id (int): The ID of the agent workflow.

        Returns:
            AgentWorkflowStep | None: The first step with ``step_type == 'TRIGGER'``, or None.
        """
        return session.query(AgentWorkflowStep).filter(
            AgentWorkflowStep.agent_workflow_id == workflow_id,
            AgentWorkflowStep.step_type == 'TRIGGER',
        ).first()

    @classmethod
    def find_by_id(cls, session, id: int):
        """Return the workflow with primary key ``id``, or None (never creates one)."""
        return session.query(AgentWorkflow).filter(AgentWorkflow.id == id).first()

    @classmethod
    def find_by_name(cls, session, name: str):
        """Return the first workflow called ``name``, or None (never creates one)."""
        return session.query(AgentWorkflow).filter(AgentWorkflow.name == name).first()

    @classmethod
    def find_or_create_by_name(cls, session, name: str, description: str):
        """Return the workflow called ``name``, creating and committing it with ``description`` if absent."""
        workflow = AgentWorkflow.find_by_name(session, name)
        if workflow is not None:
            return workflow
        workflow = AgentWorkflow(name=name, description=description)
        session.add(workflow)
        session.commit()
        return workflow
================================================
FILE: superagi/models/workflows/agent_workflow_step.py
================================================
import json
from sqlalchemy import Column, Integer, String
from sqlalchemy.dialects.postgresql import JSONB
from superagi.lib.logger import logger
from superagi.models.base_model import DBBaseModel
from superagi.models.workflows.agent_workflow_step_tool import AgentWorkflowStepTool
from superagi.models.workflows.agent_workflow_step_wait import AgentWorkflowStepWait
from superagi.models.workflows.iteration_workflow import IterationWorkflow
class AgentWorkflowStep(DBBaseModel):
"""
Step of an agent workflow
Attributes:
id (int): The unique identifier of the agent workflow step.
agent_workflow_id (int): The ID of the agent workflow to which this step belongs.
unique_id (str): The unique identifier of the step.
step_type (str): The type of the step (TRIGGER, NORMAL).
action_type (str): The type of the action (TOOL, ITERATION_WORKFLOW, LLM).
action_reference_id: Reference id of the tool/iteration workflow/llm
next_steps: Next steps output and step id.
"""
__tablename__ = 'agent_workflow_steps'
id = Column(Integer, primary_key=True)
agent_workflow_id = Column(Integer)
unique_id = Column(String)
step_type = Column(String) # TRIGGER, NORMAL
action_type = Column(String) # TOOL, ITERATION_WORKFLOW, LLM, WAIT_STEP
action_reference_id = Column(Integer) # id of the action
next_steps = Column(JSONB) # edge_ref_id, response, step_id
def __repr__(self):
    """
    Returns a string representation of the AgentWorkflowStep object.

    Each value is printed under its own column name (previously agent_workflow_id,
    unique_id and step_type were mislabelled as status, prompt and agent_id).

    Returns:
        str: String representation of the AgentWorkflowStep.
    """
    return (f"AgentWorkflowStep(id={self.id}, agent_workflow_id={self.agent_workflow_id}, "
            f"unique_id='{self.unique_id}', step_type={self.step_type}, action_type={self.action_type}, "
            f"action_reference_id={self.action_reference_id}, next_steps={self.next_steps})")
def to_dict(self):
    """
    Serialise the step's columns into a plain dict.

    Key order is fixed (it determines the field order of to_json's output).

    Returns:
        dict: Column name -> value.
    """
    ordered_fields = ('id', 'agent_workflow_id', 'unique_id', 'step_type',
                      'next_steps', 'action_type', 'action_reference_id')
    return {field: getattr(self, field) for field in ordered_fields}
def to_json(self):
    """
    Serialise the step (via to_dict) to a JSON string.

    Returns:
        str: JSON text.
    """
    serialised = json.dumps(self.to_dict())
    return serialised
@classmethod
def from_json(cls, json_data):
    """
    Rebuild a step from JSON produced by to_json.

    Args:
        json_data (str): JSON object holding every step column.

    Returns:
        AgentWorkflowStep: A new, unsaved instance.

    Raises:
        KeyError: If a required key is missing.
    """
    payload = json.loads(json_data)
    # Keys are read in this order, so the first missing one is the one reported.
    required = ('id', 'agent_workflow_id', 'unique_id', 'step_type',
                'action_type', 'action_reference_id', 'next_steps')
    kwargs = {field: payload[field] for field in required}
    return cls(**kwargs)
@classmethod
def find_by_unique_id(cls, session, unique_id: str):
    """Return the first step whose ``unique_id`` matches, or None (read-only lookup)."""
    matches = session.query(AgentWorkflowStep).filter(AgentWorkflowStep.unique_id == unique_id)
    return matches.first()
@classmethod
def find_by_id(cls, session, step_id: int):
    """Return the step with primary key ``step_id``, or None."""
    matches = session.query(AgentWorkflowStep).filter(AgentWorkflowStep.id == step_id)
    return matches.first()
@classmethod
def find_or_create_tool_workflow_step(cls, session, agent_workflow_id: int, unique_id: str,
                                      tool_name: str, input_instruction: str,
                                      output_instruction: str = "", step_type="NORMAL",
                                      history_enabled: bool = True, completion_prompt: str = None):
    """ Find or create a tool workflow step

    The step is keyed by (agent_workflow_id, unique_id). Its tool action is found or
    created via AgentWorkflowStepTool.find_or_create_tool. On every call the step's
    step_type, action fields are rewritten and next_steps is reset to [], then committed.

    Args:
        session: db session
        agent_workflow_id: id of the agent workflow
        unique_id: unique id of the step
        tool_name: name of the tool
        input_instruction: input instruction of the tool
        output_instruction: output instruction of the tool
        step_type: type of the step
        history_enabled: whether to enable history for the step
        completion_prompt: completion prompt in the llm; a JSON-tool-call prompt is used when None

    Returns:
        AgentWorkflowStep.
    """
    workflow_step = session.query(AgentWorkflowStep).filter(
        AgentWorkflowStep.agent_workflow_id == agent_workflow_id, AgentWorkflowStep.unique_id == unique_id).first()

    if completion_prompt is None:
        completion_prompt = f"Respond with only valid JSON conforming to the given json schema. Response should contain tool name and tool arguments to achieve the given instruction."
    step_tool = AgentWorkflowStepTool.find_or_create_tool(session, unique_id, tool_name,
                                                          input_instruction, output_instruction,
                                                          history_enabled, completion_prompt)
    if workflow_step is None:
        workflow_step = AgentWorkflowStep(unique_id=unique_id, step_type=step_type,
                                          agent_workflow_id=agent_workflow_id)
        session.add(workflow_step)
        session.commit()

    # Overwritten on every call, so an existing step is updated in place.
    workflow_step.step_type = step_type
    workflow_step.agent_workflow_id = agent_workflow_id
    workflow_step.action_reference_id = step_tool.id
    workflow_step.action_type = "TOOL"
    # Edges are cleared here; presumably the caller re-adds next steps afterwards — confirm.
    workflow_step.next_steps = []
    # NOTE(review): no completion_prompt column is declared on AgentWorkflowStep in this view,
    # so this assignment appears not to be persisted; the prompt is stored on step_tool above.
    workflow_step.completion_prompt = completion_prompt
    session.commit()
    return workflow_step
@classmethod
def find_or_create_wait_workflow_step(cls, session, agent_workflow_id: int, unique_id: str,
                                      wait_description: str, delay: int, step_type="NORMAL"):
    """Find or create a WAIT_STEP step of an agent workflow.

    The backing AgentWorkflowStepWait row is upserted via
    AgentWorkflowStepWait.find_or_create_wait. The step (keyed by
    agent_workflow_id + unique_id) is created if missing, pointed at the wait
    row, and its next_steps are reset to an empty list.
    """
    logger.info("Finding or creating wait step")
    existing_step = session.query(AgentWorkflowStep).filter(
        AgentWorkflowStep.agent_workflow_id == agent_workflow_id,
        AgentWorkflowStep.unique_id == unique_id,
    ).first()
    wait_row = AgentWorkflowStepWait.find_or_create_wait(session=session, step_unique_id=unique_id,
                                                         description=wait_description, delay=delay)

    step = existing_step
    if step is None:
        step = AgentWorkflowStep(unique_id=unique_id, step_type=step_type,
                                 agent_workflow_id=agent_workflow_id)
        session.add(step)
        session.commit()

    step.step_type = step_type
    step.agent_workflow_id = agent_workflow_id
    step.action_reference_id = wait_row.id
    step.action_type = "WAIT_STEP"
    step.next_steps = []
    session.commit()
    return step
@classmethod
def find_or_create_iteration_workflow_step(cls, session, agent_workflow_id: int, unique_id: str,
                                           iteration_workflow_name: str, step_type="NORMAL"):
    """ Find or create an iteration workflow step

    The step is looked up by (agent_workflow_id, unique_id) and created if
    missing. It is then pointed at the IterationWorkflow named
    ``iteration_workflow_name``, and its next_steps are reset to an empty list.

    Args:
        session: db session
        agent_workflow_id: id of the agent workflow
        unique_id: unique id of the step
        iteration_workflow_name: name of the iteration workflow
        step_type: type of the step

    Returns:
        AgentWorkflowStep.

    Raises:
        ValueError: if no IterationWorkflow with that name exists. The check runs
            before any step row is created, so a bad name does not leave a
            half-configured step committed.
    """
    workflow_step = session.query(AgentWorkflowStep).filter(
        AgentWorkflowStep.agent_workflow_id == agent_workflow_id,
        AgentWorkflowStep.unique_id == unique_id).first()
    iteration_workflow = IterationWorkflow.find_workflow_by_name(session, iteration_workflow_name)
    if iteration_workflow is None:
        raise ValueError(f"Iteration workflow not found: {iteration_workflow_name}")
    if workflow_step is None:
        workflow_step = AgentWorkflowStep(unique_id=unique_id, step_type=step_type,
                                          agent_workflow_id=agent_workflow_id)
        session.add(workflow_step)
        session.commit()
    workflow_step.step_type = step_type
    workflow_step.agent_workflow_id = agent_workflow_id
    workflow_step.action_reference_id = iteration_workflow.id
    workflow_step.action_type = "ITERATION_WORKFLOW"
    workflow_step.next_steps = []
    session.commit()
    return workflow_step
@classmethod
def add_next_workflow_step(cls, session, current_agent_step_id: int, next_step_id: int,
                           step_response: str = "default"):
    """ Add or update an entry in the current step's next_steps column

    Each entry is ``{"step_response": str, "step_id": str}``, where step_id is the
    next step's unique_id, or "-1" to mark workflow completion. If an entry for
    the same next step already exists, its step_response is replaced. Otherwise
    a new entry is appended.

    Args:
        session: db session
        current_agent_step_id: id of the current agent step
        next_step_id: id of the next agent step, or -1 for "complete"
        step_response: response of the current step that routes to the next step

    Returns:
        AgentWorkflowStep: the updated current step.

    Raises:
        ValueError: if ``next_step_id`` is not -1 and no such step exists.
    """
    next_unique_id = "-1"
    if next_step_id != -1:
        next_workflow_step = AgentWorkflowStep.find_by_id(session, next_step_id)
        if next_workflow_step is None:
            raise ValueError(f"Next workflow step not found: {next_step_id}")
        next_unique_id = next_workflow_step.unique_id
    current_step = session.query(AgentWorkflowStep).filter(AgentWorkflowStep.id == current_agent_step_id).first()
    # Work on a JSON round-tripped copy and reassign it below, presumably so the
    # ORM sees an attribute assignment rather than an in-place JSON mutation.
    # A NULL column is treated as an empty list.
    next_steps = json.loads(json.dumps(current_step.next_steps or []))
    target_id = str(next_unique_id)
    response_text = str(step_response)
    existing_step = next((step for step in next_steps if str(step.get("step_id")) == target_id), None)
    if existing_step is not None:
        existing_step["step_response"] = response_text
    else:
        next_steps.append({"step_response": response_text, "step_id": target_id})
    current_step.next_steps = next_steps
    session.commit()
    return current_step
@classmethod
def fetch_default_next_step(cls, session, current_agent_step_id: int):
    """ Return the step routed to by the current step's "default" response

    Args:
        session: db session
        current_agent_step_id: id of the current agent step

    Returns:
        The AgentWorkflowStep found by the first entry whose step_response is
        exactly "default", or None when there is no such entry.
    """
    current_step = AgentWorkflowStep.find_by_id(session, current_agent_step_id)
    default_ids = [entry["step_id"] for entry in current_step.next_steps
                   if entry["step_response"] == "default"]
    if not default_ids:
        return None
    return AgentWorkflowStep.find_by_unique_id(session, default_ids[0])
@classmethod
def fetch_next_step(cls, session, current_agent_step_id: int, step_response: str):
    """ Resolve the next step for a given response of the current step

    Matching is case-insensitive. An entry matching ``step_response`` wins;
    otherwise the "default" entry is used. A step_id of "-1" means the workflow
    is finished.

    Args:
        session: db session
        current_agent_step_id: id of the current agent step
        step_response: response of the current step

    Returns:
        "COMPLETE", the next AgentWorkflowStep, or None when neither a matching
        nor a default entry exists.
    """
    current_step = AgentWorkflowStep.find_by_id(session, current_agent_step_id)
    next_steps = current_step.next_steps

    def resolve(entry):
        # "-1" is the sentinel for "no further step".
        if str(entry["step_id"]) == "-1":
            return "COMPLETE"
        return AgentWorkflowStep.find_by_unique_id(session, entry["step_id"])

    matches = [entry for entry in next_steps
               if str(entry["step_response"]).lower() == step_response.lower()]
    if matches:
        return resolve(matches[0])

    logger.info(f"Could not find next step for step_id: {current_agent_step_id} and step_response: {step_response}")
    defaults = [entry for entry in next_steps if str(entry["step_response"]).lower() == "default"]
    if defaults:
        return resolve(defaults[0])
    return None
================================================
FILE: superagi/models/workflows/agent_workflow_step_tool.py
================================================
import json
from sqlalchemy import Column, Integer, String, Text, Boolean
from sqlalchemy.dialects.postgresql import JSONB
from superagi.models.base_model import DBBaseModel
class AgentWorkflowStepTool(DBBaseModel):
    """
    Tool configuration referenced by a TOOL-type agent workflow step.

    Attributes:
        id (int): Primary key of the tool configuration row.
        tool_name (str): Tool name.
        unique_id (str): Lookup key; find_or_create_tool builds it as
            "<step_unique_id>_<tool_name>".
        input_instruction (str): Input instruction to the tool.
        output_instruction (str): Output instruction to the tool.
        history_enabled (bool): Whether history is enabled in the step.
        completion_prompt (str): Completion prompt in the LLM conversation.
    """

    __tablename__ = 'agent_workflow_step_tools'

    id = Column(Integer, primary_key=True)
    tool_name = Column(String)
    unique_id = Column(String)
    input_instruction = Column(Text)
    output_instruction = Column(Text)
    history_enabled = Column(Boolean)
    completion_prompt = Column(Text)

    def __repr__(self):
        """
        Returns a string representation of the AgentWorkflowStepTool object.

        Only real columns are referenced, so repr() cannot raise AttributeError.

        Returns:
            str: String representation of the AgentWorkflowStepTool.
        """
        return f"AgentWorkflowStepTool(id={self.id}, " \
               f"tool_name='{self.tool_name}', unique_id='{self.unique_id}')"

    def to_dict(self):
        """
        Converts the AgentWorkflowStepTool object to a dictionary.

        Returns:
            dict: Dictionary representation of the AgentWorkflowStepTool.
        """
        return {
            'id': self.id,
            'tool_name': self.tool_name,
            'unique_id': self.unique_id,
            'input_instruction': self.input_instruction,
            'output_instruction': self.output_instruction,
            'history_enabled': self.history_enabled,
            'completion_prompt': self.completion_prompt,
        }

    def to_json(self):
        """
        Converts the AgentWorkflowStepTool object to a JSON string.

        Returns:
            str: JSON string representation of the AgentWorkflowStepTool.
        """
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """
        Creates an AgentWorkflowStepTool object from a JSON string.

        ``unique_id`` is optional in the payload so JSON produced before it was
        serialised can still be loaded.

        Args:
            json_data (str): JSON string representing the AgentWorkflowStepTool.

        Returns:
            AgentWorkflowStepTool: Object created from the JSON string.
        """
        data = json.loads(json_data)
        return cls(
            id=data['id'],
            tool_name=data['tool_name'],
            unique_id=data.get('unique_id'),
            input_instruction=data['input_instruction'],
            output_instruction=data['output_instruction'],
            history_enabled=data['history_enabled'],
            completion_prompt=data['completion_prompt'],
        )

    @classmethod
    def find_by_id(cls, session, step_id: int):
        """Return the tool row with primary key ``step_id``, or None."""
        return session.query(AgentWorkflowStepTool).filter(AgentWorkflowStepTool.id == step_id).first()

    @classmethod
    def find_or_create_tool(cls, session, step_unique_id: str, tool_name: str,
                            input_instruction: str, output_instruction: str,
                            history_enabled: bool = False, completion_prompt: str = None):
        """
        Finds or creates a tool in the database, then overwrites its fields.

        Args:
            session (Session): SQLAlchemy session object.
            step_unique_id (str): Unique ID of the step.
            tool_name (str): Name of the tool.
            input_instruction (str): Tool input instructions.
            output_instruction (str): Tool output instructions.
            history_enabled (bool): Whether history is enabled for the tool.
            completion_prompt (str): Completion prompt for the tool.

        Returns:
            AgentWorkflowStepTool: The committed AgentWorkflowStepTool object.
        """
        # One tool row per (step, tool name) pair.
        unique_id = f"{step_unique_id}_{tool_name}"
        tool = session.query(AgentWorkflowStepTool).filter_by(
            unique_id=unique_id
        ).first()
        if tool is None:
            tool = AgentWorkflowStepTool(tool_name=tool_name, unique_id=unique_id,
                                         input_instruction=input_instruction,
                                         output_instruction=output_instruction,
                                         history_enabled=history_enabled,
                                         completion_prompt=completion_prompt)
            session.add(tool)
        else:
            tool.tool_name = tool_name
            tool.input_instruction = input_instruction
            tool.output_instruction = output_instruction
            tool.history_enabled = history_enabled
            tool.completion_prompt = completion_prompt
        session.commit()
        return tool
================================================
FILE: superagi/models/workflows/agent_workflow_step_wait.py
================================================
import json
from sqlalchemy import Column, Integer, String, DateTime
from superagi.models.base_model import DBBaseModel
class AgentWorkflowStepWait(DBBaseModel):
    """
    Wait configuration referenced by a WAIT_STEP agent workflow step.

    Attributes:
        id (int): The unique identifier of the wait block step.
        name (str): The name of the wait block step (find_or_create_wait sets it to unique_id).
        description (str): The description of the wait block step.
        unique_id (str): Lookup key, built as "<step_unique_id>_wait".
        delay (int): The delay time in seconds.
        wait_begin_time (DateTime): The start time of the wait block.
        status (str): One of 'PENDING', 'WAITING', 'COMPLETED'.
    """

    __tablename__ = 'agent_workflow_step_waits'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    description = Column(String)
    unique_id = Column(String)
    delay = Column(Integer)  # Delay is stored in seconds
    wait_begin_time = Column(DateTime)
    status = Column(String)  # 'PENDING', 'WAITING', 'COMPLETED'

    def __repr__(self):
        """
        Returns a string representation of the WaitBlockStep object.

        Returns:
            str: String representation of the WaitBlockStep.
        """
        return f"WaitBlockStep(id={self.id}, name='{self.name}', delay='{self.delay}', " \
               f"wait_begin_time='{self.wait_begin_time}')"

    def to_dict(self):
        """
        Converts the WaitBlockStep object to a dictionary.

        Returns:
            dict: Dictionary representation of the WaitBlockStep. wait_begin_time
            is returned as stored (a datetime or None).
        """
        return {
            'id': self.id,
            'name': self.name,
            'delay': self.delay,
            'wait_begin_time': self.wait_begin_time
        }

    def to_json(self):
        """
        Converts the WaitBlockStep object to a JSON string.

        Datetime values are encoded as ISO-8601 strings, because json.dumps
        cannot serialise datetime objects natively.

        Returns:
            str: JSON string representation of the WaitBlockStep.
        """
        from datetime import datetime

        def _encode(value):
            if isinstance(value, datetime):
                return value.isoformat()
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        return json.dumps(self.to_dict(), default=_encode)

    @classmethod
    def find_by_id(cls, session, step_id: int):
        """Return the wait row with primary key ``step_id``, or None."""
        return session.query(AgentWorkflowStepWait).filter(AgentWorkflowStepWait.id == step_id).first()

    @classmethod
    def find_or_create_wait(cls, session, step_unique_id: str, description: str, delay: int):
        """
        Find the wait row for a step (by "<step_unique_id>_wait") or create it.

        In both cases delay and description are set and status is reset to 'PENDING'.

        Returns:
            AgentWorkflowStepWait: The committed wait row.
        """
        unique_id = f"{step_unique_id}_wait"
        wait = session.query(AgentWorkflowStepWait).filter(AgentWorkflowStepWait.unique_id == unique_id).first()
        if wait is None:
            wait = AgentWorkflowStepWait(
                unique_id=unique_id,
                name=unique_id,
                delay=delay,
                description=description,
                status='PENDING'
            )
            session.add(wait)
        else:
            wait.delay = delay
            wait.description = description
            wait.status = 'PENDING'
        # commit() already flushes pending changes; no separate flush is needed.
        session.commit()
        return wait
================================================
FILE: superagi/models/workflows/iteration_workflow.py
================================================
import json
from sqlalchemy import Column, Integer, String, Text, Boolean
from superagi.models.base_model import DBBaseModel
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
class IterationWorkflow(DBBaseModel):
    """
    Iteration workflow: a set of IterationWorkflowSteps run as part of each agent execution step.

    Attributes:
        id (int): The unique identifier of the iteration workflow.
        name (str): The name of the iteration workflow (lookup key for find_workflow_by_name).
        description (str): The description of the iteration workflow.
        has_task_queue (bool): Whether the workflow uses a task queue; column default is False.
    """

    __tablename__ = 'iteration_workflows'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    description = Column(Text)
    has_task_queue = Column(Boolean, default=False)

    def __repr__(self):
        """
        Returns a string representation of the IterationWorkflow object.

        Returns:
            str: String representation of the IterationWorkflow.
        """
        return f"IterationWorkflow(id={self.id}, name='{self.name}', " \
               f"description='{self.description}', has_task_queue={self.has_task_queue})"

    def to_dict(self):
        """
        Converts the IterationWorkflow object to a dictionary.

        Returns:
            dict: Dictionary representation of the IterationWorkflow.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'has_task_queue': self.has_task_queue
        }

    def to_json(self):
        """
        Converts the IterationWorkflow object to a JSON string.

        Returns:
            str: JSON string representation of the IterationWorkflow.
        """
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """
        Creates an IterationWorkflow object from a JSON string.

        ``has_task_queue`` is optional in the payload. When it is absent the
        attribute is left unset, matching JSON written before the key existed.

        Args:
            json_data (str): JSON string representing the IterationWorkflow.

        Returns:
            IterationWorkflow: IterationWorkflow object created from the JSON string.
        """
        data = json.loads(json_data)
        fields = {
            'id': data['id'],
            'name': data['name'],
            'description': data['description'],
        }
        if 'has_task_queue' in data:
            fields['has_task_queue'] = data['has_task_queue']
        return cls(**fields)

    @classmethod
    def fetch_trigger_step_id(cls, session, workflow_id):
        """
        Fetches the trigger step of the specified iteration workflow.

        Args:
            session: The session object used for database operations.
            workflow_id (int): The ID of the iteration workflow.

        Returns:
            IterationWorkflowStep: The first step with step_type 'TRIGGER', or None.
            Despite the method name, the step object is returned, not its id.
        """
        trigger_step = session.query(IterationWorkflowStep).filter(
            IterationWorkflowStep.iteration_workflow_id == workflow_id,
            IterationWorkflowStep.step_type == 'TRIGGER').first()
        return trigger_step

    @classmethod
    def find_workflow_by_name(cls, session, name: str):
        """
        Finds an IterationWorkflow by name.

        Args:
            session (Session): SQLAlchemy session object.
            name (str): Name of the IterationWorkflow.

        Returns:
            IterationWorkflow: The workflow with the given name, or None.
        """
        return session.query(IterationWorkflow).filter(IterationWorkflow.name == name).first()

    @classmethod
    def find_or_create_by_name(cls, session, name: str, description: str, has_task_queue: bool = False):
        """
        Finds an IterationWorkflow by name or creates it if it does not exist.

        The description is only applied on creation; has_task_queue is set on every call.

        Args:
            session (Session): SQLAlchemy session object.
            name (str): Name of the IterationWorkflow.
            description (str): Description of the IterationWorkflow.
            has_task_queue (bool): Whether the workflow uses a task queue.

        Returns:
            IterationWorkflow: The committed workflow.
        """
        iteration_workflow = session.query(IterationWorkflow).filter(
            IterationWorkflow.name == name).first()
        if iteration_workflow is None:
            iteration_workflow = IterationWorkflow(name=name, description=description)
            session.add(iteration_workflow)
            session.commit()
        iteration_workflow.has_task_queue = has_task_queue
        session.commit()
        return iteration_workflow

    @classmethod
    def find_by_id(cls, session, id: int):
        """ Find the iteration workflow by id"""
        return session.query(IterationWorkflow).filter(IterationWorkflow.id == id).first()
================================================
FILE: superagi/models/workflows/iteration_workflow_step.py
================================================
import json
from sqlalchemy import Column, Integer, String, Text, Boolean
from sqlalchemy.dialects.postgresql import JSONB
from superagi.models.base_model import DBBaseModel
class IterationWorkflowStep(DBBaseModel):
    """
    Step of an iteration workflow

    Attributes:
        id (int): The unique identifier of the iteration workflow step.
        iteration_workflow_id (int): The ID of the iteration workflow to which this step belongs.
        unique_id (str): The unique identifier of the step (lookup key for find_or_create_step).
        prompt (str): The prompt for the step.
        variables (str): The variables associated with the step.
        output_type (str): The output type of the step.
        step_type (str): The type of the step (TRIGGER, NORMAL).
        next_step_id (int): The ID of the next step in the workflow (-1 after find_or_create_step).
        history_enabled (bool): Indicates whether history is enabled for the step.
        completion_prompt (str): The completion prompt for the step.
    """

    __tablename__ = 'iteration_workflow_steps'

    id = Column(Integer, primary_key=True)
    iteration_workflow_id = Column(Integer)
    unique_id = Column(String)
    prompt = Column(Text)
    variables = Column(Text)
    output_type = Column(String)
    step_type = Column(String)  # TRIGGER, NORMAL
    next_step_id = Column(Integer)
    history_enabled = Column(Boolean)
    completion_prompt = Column(Text)

    def __repr__(self):
        """
        Returns a string representation of the IterationWorkflowStep object.

        Returns:
            str: String representation of the IterationWorkflowStep.
        """
        return f"IterationWorkflowStep(id={self.id}, next_step_id='{self.next_step_id}', " \
               f"prompt='{self.prompt}')"

    def to_dict(self):
        """
        Converts the IterationWorkflowStep object to a dictionary.

        Only columns of this model are read. There is no agent_id column, so the
        owning workflow is exposed as iteration_workflow_id.

        Returns:
            dict: Dictionary representation of the IterationWorkflowStep.
        """
        return {
            'id': self.id,
            'iteration_workflow_id': self.iteration_workflow_id,
            'unique_id': self.unique_id,
            'next_step_id': self.next_step_id,
            'prompt': self.prompt,
            'variables': self.variables,
            'output_type': self.output_type,
            'step_type': self.step_type,
            'history_enabled': self.history_enabled,
            'completion_prompt': self.completion_prompt,
        }

    def to_json(self):
        """
        Converts the IterationWorkflowStep object to a JSON string.

        Returns:
            str: JSON string representation of the IterationWorkflowStep.
        """
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_data):
        """
        Creates an IterationWorkflowStep object from a JSON string.

        The owning workflow id is read from 'iteration_workflow_id'. It falls back
        to a legacy 'agent_id' key when that is absent. Passing agent_id to the
        constructor would fail, because the model has no such column.

        Args:
            json_data (str): JSON string representing the IterationWorkflowStep.

        Returns:
            IterationWorkflowStep: Object created from the JSON string.
        """
        data = json.loads(json_data)
        return cls(
            id=data['id'],
            prompt=data['prompt'],
            iteration_workflow_id=data.get('iteration_workflow_id', data.get('agent_id')),
            next_step_id=data['next_step_id']
        )

    @classmethod
    def find_by_id(cls, session, step_id: int):
        """Return the step with primary key ``step_id``, or None."""
        return session.query(IterationWorkflowStep).filter(IterationWorkflowStep.id == step_id).first()

    @classmethod
    def find_or_create_step(cls, session, iteration_workflow_id: int, unique_id: str,
                            prompt: str, variables: str, step_type: str, output_type: str,
                            completion_prompt: str = "", history_enabled: bool = False):
        """
        Find a step by unique_id (across all iteration workflows) or create it, then overwrite its fields.

        next_step_id is reset to -1. completion_prompt is only overwritten when a
        non-empty value is given.

        Returns:
            IterationWorkflowStep: The committed step.
        """
        workflow_step = session.query(IterationWorkflowStep).filter(IterationWorkflowStep.unique_id == unique_id).first()
        if workflow_step is None:
            workflow_step = IterationWorkflowStep(unique_id=unique_id)
            session.add(workflow_step)
            session.commit()
        workflow_step.prompt = prompt
        workflow_step.variables = variables
        workflow_step.step_type = step_type
        workflow_step.output_type = output_type
        workflow_step.iteration_workflow_id = iteration_workflow_id
        workflow_step.next_step_id = -1
        workflow_step.history_enabled = history_enabled
        if completion_prompt:
            workflow_step.completion_prompt = completion_prompt
        session.commit()
        return workflow_step
================================================
FILE: superagi/resource_manager/__init__.py
================================================
================================================
FILE: superagi/resource_manager/file_manager.py
================================================
import csv
from sqlalchemy.orm import Session
from superagi.config.config import get_config
import os
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.types.storage_types import StorageType
class FileManager:
    """
    Reads and writes resource files, scoped to an agent execution when
    ``agent_id`` is set and to the global resource directory otherwise.
    Every successful write is followed by write_to_s3, which uploads the file
    when the resulting resource's storage type is S3.
    """

    def __init__(self, session: Session, agent_id: int = None, agent_execution_id: int = None):
        self.session = session
        self.agent_id = agent_id
        self.agent_execution_id = agent_execution_id

    def _resolve_path(self, file_name: str):
        """Return the agent-scoped write path if an agent is set, else the global resource path."""
        if self.agent_id is not None:
            return self.get_agent_resource_path(file_name)
        return ResourceHelper.get_resource_path(file_name)

    def write_binary_file(self, file_name: str, data):
        """Write ``data`` as bytes, mirror it to S3 if configured, and return a status message.

        Errors raised while writing or uploading are returned as an
        "Error write_binary_file: ..." string rather than raised.
        """
        final_path = self._resolve_path(file_name)
        try:
            with open(final_path, mode="wb") as img:
                img.write(data)
            # Upload only after the with-block has closed (and flushed) the file.
            self.write_to_s3(file_name, final_path)
            logger.info(f"Binary {file_name} saved successfully")
            return f"Binary {file_name} saved successfully"
        except Exception as err:
            return f"Error write_binary_file: {err}"

    def write_to_s3(self, file_name, final_path):
        """Register ``final_path`` as a written resource and upload it when its storage type is S3."""
        with open(final_path, 'rb') as img:
            resource = ResourceHelper.make_written_file_resource(file_name=file_name,
                                                                 agent=Agent.get_agent_from_id(self.session,
                                                                                               self.agent_id),
                                                                 agent_execution=AgentExecution
                                                                 .get_agent_execution_from_id(self.session,
                                                                                              self.agent_execution_id),
                                                                 session=self.session)
            if resource.storage_type == StorageType.S3.value:
                s3_helper = S3Helper()
                s3_helper.upload_file(img, path=resource.path)

    def write_file(self, file_name: str, content):
        """Write text ``content``, mirror it to S3 if configured, and return a status message.

        Errors are returned as an "Error write_file: ..." string rather than raised.
        """
        final_path = self._resolve_path(file_name)
        try:
            with open(final_path, mode="w") as file:
                file.write(content)
            self.write_to_s3(file_name, final_path)
            logger.info(f"{file_name} - File written successfully")
            return f"{file_name} - File written successfully"
        except Exception as err:
            return f"Error write_file: {err}"

    def write_csv_file(self, file_name: str, csv_data):
        """Write ``csv_data`` (an iterable of rows), mirror it to S3 if configured, and return a status message.

        Errors are returned as an "Error write_csv_file: ..." string rather than raised.
        """
        final_path = self._resolve_path(file_name)
        try:
            with open(final_path, mode="w", newline="") as file:
                writer = csv.writer(file, lineterminator="\n")
                writer.writerows(csv_data)
            # Upload after the file is closed so buffered rows are on disk.
            self.write_to_s3(file_name, final_path)
            logger.info(f"{file_name} - File written successfully")
            return f"{file_name} - File written successfully"
        except Exception as err:
            return f"Error write_csv_file: {err}"

    def get_agent_resource_path(self, file_name: str):
        """Return the write path for ``file_name`` within this agent execution's resource directory."""
        return ResourceHelper.get_agent_write_resource_path(file_name, agent=Agent.get_agent_from_id(self.session,
                                                                                                     self.agent_id),
                                                            agent_execution=AgentExecution
                                                            .get_agent_execution_from_id(self.session,
                                                                                         self.agent_execution_id))

    def read_file(self, file_name: str):
        """Return the text content of ``file_name``; on failure an error message string is returned instead."""
        final_path = self._resolve_path(file_name)
        try:
            with open(final_path, mode="r") as file:
                content = file.read()
            logger.info(f"{file_name} - File read successfully")
            return content
        except Exception as err:
            return f"Error while reading file {file_name}: {err}"

    def get_files(self):
        """
        Gets all file names in the resolved resource directory.

        Returns:
            A list of file names, or an empty list if the directory cannot be listed.
        """
        final_path = self._resolve_path("")
        try:
            # List all files in the directory
            files = os.listdir(final_path)
        except Exception as err:
            logger.error(f"Error while accessing files in {final_path}: {err}")
            files = []
        return files
================================================
FILE: superagi/resource_manager/llama_document_summary.py
================================================
import os
from llama_index.indices.response import ResponseMode
from llama_index.schema import Document
from superagi.config.config import get_config
class LlamaDocumentSummary:
    """Generates document summaries with llama_index using an OpenAI chat model."""

    def __init__(self, model_name=None, model_source="OpenAi", model_api_key: str = None):
        """
        :param model_name: LLM name. When None, RESOURCES_SUMMARY_MODEL_NAME is
            read from config at construction time (default "gpt-3.5-turbo"),
            rather than once at import time.
        :param model_source: model provider label
        :param model_api_key: API key used when OPENAI_API_KEY is not configured
        """
        if model_name is None:
            model_name = get_config("RESOURCES_SUMMARY_MODEL_NAME", "gpt-3.5-turbo")
        self.model_name = model_name
        self.model_api_key = model_api_key
        self.model_source = model_source

    def generate_summary_of_document(self, documents: list[Document]):
        """
        Generates summary of the documents
        :param documents: list of Document objects
        :return: summary of the first document in the built summary index, or
            None when ``documents`` is None or empty
        """
        if documents is None or not documents:
            return
        from llama_index import LLMPredictor, ServiceContext, ResponseSynthesizer, DocumentSummaryIndex
        api_key = get_config("OPENAI_API_KEY", "") or self.model_api_key
        # os.environ only accepts str values; assigning None would raise TypeError.
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        llm_predictor_chatgpt = LLMPredictor(llm=self._build_llm())
        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt, chunk_size=1024)
        response_synthesizer = ResponseSynthesizer.from_args(response_mode=ResponseMode.TREE_SUMMARIZE, use_async=True)
        doc_summary_index = DocumentSummaryIndex.from_documents(
            documents=documents,
            service_context=service_context,
            response_synthesizer=response_synthesizer
        )
        return doc_summary_index.get_document_summary(documents[0].doc_id)

    def generate_summary_of_texts(self, texts: list[str]):
        """
        Generates summary of the texts
        :param texts: list of texts
        :return: summary of the texts
        :raises ValueError: if ``texts`` is None or empty
        """
        from llama_index import Document
        if texts is not None and len(texts) > 0:
            documents = [Document(doc_id=f"doc_id_{i}", text=text) for i, text in enumerate(texts)]
            return self.generate_summary_of_document(documents)
        raise ValueError("texts must be provided")

    def _build_llm(self):
        """
        Builds the LLM model
        :return: LLM model object
        :raises Exception: if the model name is not one of the supported OpenAI chat models
        """
        open_ai_models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
        if self.model_name in open_ai_models:
            from langchain.chat_models import ChatOpenAI
            openai_api_key = get_config("OPENAI_API_KEY") or self.model_api_key
            return ChatOpenAI(temperature=0, model_name=self.model_name,
                              openai_api_key=openai_api_key)
        raise Exception(f"Model name {self.model_name} not supported for document summary")
================================================
FILE: superagi/resource_manager/llama_vector_store_factory.py
================================================
from llama_index.vector_stores.types import VectorStore
from superagi.config.config import get_config
from superagi.types.vector_store_types import VectorStoreType
class LlamaVectorStoreFactory:
    """
    Builds a llama_index vector store for the requested backend.

    Supported backends: Pinecone, Redis, Chroma and Qdrant. Connection settings
    are read from config, with localhost defaults.

    :param vector_store_name: VectorStoreType
    :param index_name: str
    :return: VectorStore object
    """

    def __init__(self, vector_store_name: VectorStoreType, index_name: str):
        self.vector_store_name = vector_store_name
        self.index_name = index_name

    def get_vector_store(self) -> VectorStore:
        """
        Returns the vector store matching ``vector_store_name``.

        :return: VectorStore object
        :raises ValueError: for an unsupported store type
        """
        kind = self.vector_store_name
        if kind == VectorStoreType.PINECONE:
            return self._pinecone_store()
        if kind == VectorStoreType.REDIS:
            return self._redis_store()
        if kind == VectorStoreType.CHROMA:
            return self._chroma_store()
        if kind == VectorStoreType.QDRANT:
            return self._qdrant_store()
        raise ValueError(str(kind) + " vector store is not supported yet.")

    def _pinecone_store(self):
        """Pinecone store for ``index_name``."""
        from llama_index.vector_stores import PineconeVectorStore
        return PineconeVectorStore(self.index_name)

    def _redis_store(self):
        """Redis store at REDIS_VECTOR_STORE_URL, with agent_id/resource_id as metadata fields."""
        url = get_config("REDIS_VECTOR_STORE_URL") or "redis://super__redis:6379"
        from llama_index.vector_stores import RedisVectorStore
        return RedisVectorStore(
            index_name=self.index_name,
            redis_url=url,
            metadata_fields=["agent_id", "resource_id"]
        )

    def _chroma_store(self):
        """Chroma store backed by a REST client at CHROMA_HOST_NAME:CHROMA_PORT."""
        from llama_index.vector_stores import ChromaVectorStore
        import chromadb
        from chromadb.config import Settings
        host = get_config("CHROMA_HOST_NAME") or "localhost"
        port = get_config("CHROMA_PORT") or 8000
        client = chromadb.Client(
            Settings(chroma_api_impl="rest", chroma_server_host=host,
                     chroma_server_http_port=port))
        collection = client.get_or_create_collection(self.index_name)
        return ChromaVectorStore(collection)

    def _qdrant_store(self):
        """Qdrant store at QDRANT_HOST_NAME:QDRANT_PORT using ``index_name`` as the collection."""
        from llama_index.vector_stores import QdrantVectorStore
        host = get_config("QDRANT_HOST_NAME") or "localhost"
        port = get_config("QDRANT_PORT") or 6333
        from qdrant_client import QdrantClient
        client = QdrantClient(host=host, port=port)
        return QdrantVectorStore(client=client, collection_name=self.index_name)
================================================
FILE: superagi/resource_manager/resource_manager.py
================================================
import os
from llama_index import SimpleDirectoryReader
from sqlalchemy.orm import Session
from superagi.config.config import get_config
from superagi.helper.resource_helper import ResourceHelper
from superagi.lib.logger import logger
from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
from superagi.types.model_source_types import ModelSourceType
from superagi.types.vector_store_types import VectorStoreType
from superagi.models.agent import Agent
class ResourceManager:
    """
    Resource Manager handles creation of resources and saving them to the vector store.

    :param agent_id: The agent id to use when saving resources to the vector store.
    """

    def __init__(self, agent_id: str = None):
        self.agent_id = agent_id

    def create_llama_document(self, file_path: str):
        """
        Loads llama_index documents from a local file path.

        :param file_path: The file path to create the documents from.
        :return: A list of documents, or None if the path does not exist.
        :raises Exception: if file_path is None.
        """
        if file_path is None:
            raise Exception("file_path must be provided")
        if os.path.exists(file_path):
            documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
            return documents
        return None

    def create_llama_document_s3(self, file_path: str):
        """
        Downloads ``file_path`` from the configured S3 bucket and loads it as llama_index documents.

        The object is written to a private temporary directory, which is removed
        afterwards. Download or parse failures are logged and None is returned.

        :param file_path: S3 object key.
        :return: A list of documents, or None on failure.
        :raises Exception: if file_path is None.
        """
        if file_path is None:
            raise Exception("file_path must be provided")
        import shutil
        import tempfile
        temporary_directory = None
        try:
            import boto3
            s3 = boto3.client(
                's3',
                aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
                aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
            )
            bucket_name = get_config("BUCKET_NAME")
            file = s3.get_object(Bucket=bucket_name, Key=file_path)
            file_name = file_path.split("/")[-1]
            # Use a fresh private directory instead of "/" (usually unwritable and
            # a predictable shared path). The original file name is kept,
            # presumably so SimpleDirectoryReader can pick a reader by extension.
            temporary_directory = tempfile.mkdtemp()
            temporary_file_path = os.path.join(temporary_directory, file_name)
            with open(temporary_file_path, "wb") as f:
                f.write(file['Body'].read())
            documents = SimpleDirectoryReader(input_files=[temporary_file_path]).load_data()
            return documents
        except Exception as e:
            logger.error(f"superagi/resource_manager/resource_manager.py - create_llama_document_s3 threw : {e}")
            return None
        finally:
            if temporary_directory is not None:
                shutil.rmtree(temporary_directory, ignore_errors=True)

    def save_document_to_vector_store(self, documents: list, resource_id: str, mode_api_key: str = None,
                                      model_source: str = ""):
        """
        Saves documents to the configured resource vector store.

        Each document's metadata is tagged with agent_id and resource_id. Embedding
        is skipped for Google Palm and Replicate model sources. Index-creation
        failures are logged, not raised.

        :param documents: The documents to save to the vector store.
        :param resource_id: The resource id to use when saving the documents to the vector store.
        :param mode_api_key: The model api key used for embeddings when OPENAI_API_KEY is not configured.
        :param model_source: Model source label; None is treated as "".
        """
        from llama_index import VectorStoreIndex, StorageContext
        model_source = model_source or ""
        if ModelSourceType.GooglePalm.value in model_source or ModelSourceType.Replicate.value in model_source:
            logger.info("Resource embedding not supported for Google Palm..")
            return
        import openai
        api_key = get_config("OPENAI_API_KEY") or mode_api_key
        openai.api_key = api_key
        # os.environ only accepts str values; skip the assignment when no key is available.
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        for docs in documents:
            if docs.metadata is None:
                docs.metadata = {}
            docs.metadata["agent_id"] = str(self.agent_id)
            docs.metadata["resource_id"] = resource_id
        vector_store = None
        storage_context = None
        vector_store_name = VectorStoreType.get_vector_store_type(get_config("RESOURCE_VECTOR_STORE") or "Redis")
        vector_store_index_name = get_config("RESOURCE_VECTOR_STORE_INDEX_NAME") or "super-agent-index"
        try:
            vector_store = LlamaVectorStoreFactory(vector_store_name, vector_store_index_name).get_vector_store()
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
        except ValueError as e:
            logger.error(f"Vector store not found: {e}")
        try:
            index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
            index.set_index_id(f'Agent {self.agent_id}')
        except Exception as e:
            logger.error(f"save_document_to_vector_store - unable to create documents from vector: {e}")
        # persisting the data in case of redis (only if the store was actually built)
        if vector_store_name == VectorStoreType.REDIS and vector_store is not None:
            vector_store.persist(persist_path="")
================================================
FILE: superagi/resource_manager/resource_summary.py
================================================
from datetime import datetime
import logging
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.configuration import Configuration
from superagi.models.resource import Resource
from superagi.resource_manager.llama_document_summary import LlamaDocumentSummary
from superagi.resource_manager.resource_manager import ResourceManager
from superagi.types.model_source_types import ModelSourceType
class ResourceSummarizer:
"""Class to summarize a resource."""
def __init__(self, session, agent_id: int, model: str):
    """Bind the summarizer to a db session, an agent and a model name; resolves the agent's organisation id."""
    self.session, self.agent_id = session, agent_id
    self.organisation_id = self.__get_organisation_id()
    self.model = model
def __get_organisation_id(self):
agent = self.session.query(Agent).filter(Agent.id == self.agent_id).first()
organisation = agent.get_agent_organisation(self.session)
return organisation.id
def __get_model_api_key(self):
return Configuration.fetch_configurations(self.session, self.organisation_id, "model_api_key", self.model)
def __get_model_source(self):
return Configuration.fetch_configurations(self.session, self.organisation_id, "model_source", self.model)
def add_to_vector_store_and_create_summary(self, resource_id: int, documents: list):
"""
Add a file to the vector store and generate a summary for it.
Args:
agent_id (str): ID of the agent.
resource_id (int): ID of the resource.
openai_api_key (str): OpenAI API key.
documents (list): List of documents.
"""
model_api_key = self.__get_model_api_key()
try:
ResourceManager(str(self.agent_id)).save_document_to_vector_store(documents, str(resource_id), model_api_key,
self.__get_model_source())
except Exception as e:
logger.error("add_to_vector_store_and_create_summary: Unable to save document to vector store.", e)
def generate_agent_summary(self, generate_all: bool = False) -> str:
"""Generate a summary of all resources for an agent."""
agent_config_resource_summary = self.session.query(AgentConfiguration). \
filter(AgentConfiguration.agent_id == self.agent_id,
AgentConfiguration.key == "resource_summary").first()
resources = self.session.query(Resource).filter(Resource.agent_id == self.agent_id,
Resource.channel == 'INPUT').all()
if not resources:
return
resource_summary = " ".join([resource.name for resource in resources])
agent_last_resource = self.session.query(AgentConfiguration). \
filter(AgentConfiguration.agent_id == self.agent_id,
AgentConfiguration.key == "last_resource_time").first()
if agent_config_resource_summary is not None:
agent_config_resource_summary.value = resource_summary
else:
agent_config_resource_summary = AgentConfiguration(agent_id=self.agent_id, key="resource_summary",
value=resource_summary)
self.session.add(agent_config_resource_summary)
if agent_last_resource is not None:
agent_last_resource.value = str(resources[-1].updated_at)
else:
agent_last_resource = AgentConfiguration(agent_id=self.agent_id, key="last_resource_time",
value=str(resources[-1].updated_at))
self.session.add(agent_last_resource)
self.session.commit()
def fetch_or_create_agent_resource_summary(self, default_summary: str):
print(self.__get_model_source())
if ModelSourceType.GooglePalm.value in self.__get_model_source():
return
self.generate_agent_summary(generate_all=True)
agent_config_resource_summary = self.session.query(AgentConfiguration). \
filter(AgentConfiguration.agent_id == self.agent_id,
AgentConfiguration.key == "resource_summary").first()
resource_summary = agent_config_resource_summary.value if agent_config_resource_summary is not None else default_summary
return resource_summary
================================================
FILE: superagi/tool_manager.py
================================================
import os
from pathlib import Path
import requests
import zipfile
import json
def parse_github_url(github_url):
    """
    Reduce a GitHub URL to an ``owner/repo/branch`` string.

    Only path segments 3 and 4 (owner and repository) are read. The branch
    is always reported as ``main``, whatever the URL contains.
    """
    segments = github_url.split('/')
    owner, repo = segments[3], segments[4]
    return "/".join([owner, repo, "main"])
def download_tool(tool_url, target_folder):
    """
    Download a GitHub repository archive and extract it into ``target_folder``.

    The zipball for the owner/repo/branch returned by ``parse_github_url`` is
    saved temporarily as ``tool.zip`` inside ``target_folder``. It is then
    extracted with the archive's top-level folder stripped, and the zip is
    always removed afterwards.

    Args:
        tool_url: GitHub URL of the tool repository.
        target_folder: Existing directory to extract into.

    Raises:
        requests.HTTPError: if GitHub does not return a successful response.
        ValueError: if an archive member would be written outside ``target_folder``.
    """
    parsed_url = parse_github_url(tool_url)
    parts = parsed_url.split("/")
    owner, repo, branch, path = parts[0], parts[1], parts[2], "/".join(parts[3:])
    archive_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/{branch}"
    response = requests.get(archive_url)
    # Fail here with a clear HTTP error instead of later with BadZipFile when
    # GitHub answers with an error payload (rate limit, 404, ...).
    response.raise_for_status()
    tool_zip_file_path = os.path.join(target_folder, 'tool.zip')
    root = os.path.realpath(target_folder)
    with open(tool_zip_file_path, 'wb') as f:
        f.write(response.content)
    try:
        with zipfile.ZipFile(tool_zip_file_path, 'r') as z:
            archive_folder = f"{owner}-{repo}"
            members = [m for m in z.namelist() if m.startswith(archive_folder) and path in m]
            for member in members:
                target_name = member.replace(f"{archive_folder}/", "", 1)
                # Skip the unique hash folder while extracting:
                segments = target_name.split('/', 1)
                if len(segments) < 2 or not segments[1]:
                    continue
                target_name = segments[1]
                target_path = os.path.join(target_folder, target_name)
                # Guard against "zip slip": member names such as "../x" must not
                # escape the target folder.
                if os.path.commonpath([root, os.path.realpath(target_path)]) != root:
                    raise ValueError(f"Refusing to extract {member!r} outside {target_folder!r}")
                if member.endswith('/'):
                    os.makedirs(target_path, exist_ok=True)
                else:
                    # Archives are not required to list directory entries first.
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)
                    with open(target_path, 'wb') as outfile, z.open(member) as infile:
                        outfile.write(infile.read())
    finally:
        os.remove(tool_zip_file_path)
def download_marketplace_tool(tool_url, target_folder):
    """
    Download the marketplace repository and extract it into ``target_folder``.

    Owner and repository come from ``tool_url`` (path segments 3 and 4), and
    the ``main`` branch zipball is always used. The archive's top-level folder
    is stripped, ``.md`` files are not extracted, and the temporary
    ``tool.zip`` is always removed.

    Raises:
        requests.HTTPError: if GitHub does not return a successful response.
        ValueError: if an archive member would be written outside ``target_folder``.
    """
    parsed_url = tool_url.split("/")
    owner, repo = parsed_url[3], parsed_url[4]
    archive_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/main"
    response = requests.get(archive_url)
    # Avoid writing an error payload to disk and failing later with BadZipFile.
    response.raise_for_status()
    tool_zip_file_path = os.path.join(target_folder, 'tool.zip')
    root = os.path.realpath(target_folder)
    with open(tool_zip_file_path, 'wb') as f:
        f.write(response.content)
    try:
        with zipfile.ZipFile(tool_zip_file_path, 'r') as z:
            for member in z.namelist():
                # Drop the leading archive folder; entries without one are skipped.
                _archive_folder, separator, relative_name = member.partition('/')
                if not separator:
                    continue
                target_name = os.path.join(target_folder, relative_name)
                # Guard against "zip slip" paths escaping the target folder.
                if os.path.commonpath([root, os.path.realpath(target_name)]) != root:
                    raise ValueError(f"Refusing to extract {member!r} outside {target_folder!r}")
                if member.endswith('/'):
                    os.makedirs(target_name, exist_ok=True)
                elif not target_name.endswith('.md'):
                    os.makedirs(os.path.dirname(target_name), exist_ok=True)
                    with open(target_name, 'wb') as outfile, z.open(member) as infile:
                        outfile.write(infile.read())
    finally:
        os.remove(tool_zip_file_path)
def get_marketplace_tool_links(repo_url):
    """
    List the top-level folders of a GitHub repository as marketplace tool links.

    Args:
        repo_url: Repository in ``owner/repo`` form.

    Returns:
        dict mapping each folder name to
        ``https://github.com/<repo_url>/tree/main/<folder>``.

    Raises:
        requests.HTTPError: on a non-success status (e.g. API rate limiting).
        ValueError: if the API response is not a list of directory entries.
    """
    api_url = f"https://api.github.com/repos/{repo_url}/contents"
    response = requests.get(api_url)
    response.raise_for_status()
    contents = response.json()
    # A directory listing is a JSON array; anything else (e.g. an error object)
    # would otherwise fail with an opaque TypeError when indexed below.
    if not isinstance(contents, list):
        raise ValueError(f"Unexpected response from {api_url}: {contents!r}")
    return {
        content["name"]: f"https://github.com/{repo_url}/tree/main/{content['name']}"
        for content in contents
        if content.get("type") == "dir"
    }
def update_tools_json(existing_tools_json_path, folder_links):
    """
    Merge ``folder_links`` into the ``tools`` mapping of a tools.json file.

    A missing ``tools`` key is created, and entries with the same name are
    overwritten. The file is rewritten with 4-space indentation.
    """
    with open(existing_tools_json_path, "r") as source:
        payload = json.load(source)
    payload.setdefault("tools", {}).update(folder_links)
    with open(existing_tools_json_path, "w") as sink:
        json.dump(payload, sink, indent=4)
def load_tools_config():
    """Read ``tools.json`` from the parent of this module's package directory and return its ``tools`` mapping."""
    config_file = Path(__file__).parent.parent / "tools.json"
    with open(config_file, "r") as handle:
        return json.load(handle)["tools"]
def load_marketplace_tools():
    """Register every top-level folder of TransformerOptimus/SuperAGI-Tools in ``tools.json``."""
    tools_json_path = str(Path(__file__).parent.parent) + "/tools.json"
    # The links are fetched (network call) before tools.json is rewritten.
    update_tools_json(tools_json_path, get_marketplace_tool_links("TransformerOptimus/SuperAGI-Tools"))
def is_marketplace_url(url):
    """Return True when ``url`` points into the SuperAGI-Tools repository tree."""
    marketplace_prefix = "https://github.com/TransformerOptimus/SuperAGI-Tools/tree"
    return url.startswith(marketplace_prefix)
def download_and_extract_tools():
    """
    Download every tool listed in ``tools.json``.

    Marketplace tools are extracted into ``superagi/tools/marketplace_tools``.
    Each marketplace repository is downloaded only once, however many of its
    folders are listed. Every other tool goes to
    ``superagi/tools/external_tools/<tool_name>``.
    """
    tools_config = load_tools_config()
    marketplace_folder = "superagi/tools/marketplace_tools"
    downloaded_marketplace_repos = set()
    for tool_name, tool_url in tools_config.items():
        if is_marketplace_url(tool_url):
            # download_marketplace_tool fetches the whole owner/repo zipball, so
            # repeating it for each folder of the same repo is redundant work.
            repo_key = tuple(tool_url.split("/")[3:5])
            if repo_key in downloaded_marketplace_repos:
                continue
            os.makedirs(marketplace_folder, exist_ok=True)
            download_marketplace_tool(tool_url, marketplace_folder)
            downloaded_marketplace_repos.add(repo_key)
        else:
            tool_folder = os.path.join("superagi/tools/external_tools", tool_name)
            os.makedirs(tool_folder, exist_ok=True)
            download_tool(tool_url, tool_folder)
if __name__ == "__main__":
    # Refresh tools.json from the marketplace repository first, so the download
    # step sees the newly listed marketplace tools.
    for setup_step in (load_marketplace_tools, download_and_extract_tools):
        setup_step()
================================================
FILE: superagi/tools/__init__.py
================================================
================================================
FILE: superagi/tools/apollo/__init__.py
================================================
================================================
FILE: superagi/tools/apollo/apollo_search.py
================================================
import json
from typing import Type
import requests
from pydantic import BaseModel, Field
from superagi.lib.logger import logger
from superagi.tools.base_tool import BaseTool
class ApolloSearchSchema(BaseModel):
    # Arguments of ApolloSearchTool; these descriptions end up in the schema
    # properties returned by BaseTool.args.
    person_titles: list[str] = Field(
        default=...,
        description="The titles of the people to search for.",
    )
    page: int = Field(
        default=1,
        description="The page of results to retrieve. Default value is 1.",
    )
    per_page: int = Field(
        default=25,
        description="The number of results to retrieve per page. Default value is 25.",
    )
    # Expected as [start_range, end_range].
    num_of_employees: list[int] = Field(
        default=[],
        description="The number of employees to filter by in format [start_range, end_range]. Default value is empty array.",
    )
    organization_domains: str = Field(
        default="",
        description="The organization domains to search within. It is optional field.",
    )
    person_location: str = Field(
        default="",
        description="Region country/state/city filter to search for. It is optional field.",
    )
class ApolloSearchTool(BaseTool):
    """
    Apollo Search tool

    Searches Apollo.io people search for the given titles and filters and
    returns a compact list of person records.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name = "ApolloSearch"
    # The API key is read from the APOLLO_SEARCH_KEY tool config, not from the
    # LLM's input, so the description no longer asks for it.
    description = (
        "A tool for performing a Apollo search and extracting people data. "
        "Input should include organization domains, page number, and person titles."
    )
    args_schema: Type[BaseModel] = ApolloSearchSchema

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, person_titles: list[str], page: int = 1, per_page: int = 25, num_of_employees: list[int] = None,
                 person_location: str = "", organization_domains: str = "") -> list:
        """
        Execute the Apollo search tool.

        Args:
            person_titles : The titles of the people to search for.
            page : The page of results to retrieve.
            per_page : The number of results to retrieve per page.
            num_of_employees : The number of employees to filter by in format [start_range, end_range]. It is optional.
            person_location : Region country/state/city filter to search for. It is optional.
            organization_domains : The organization domains to search within.

        Returns:
            A list of dicts with first_name, last_name, name, linkedin_url,
            email, headline and title, one per person. The list is empty when
            the request fails or finds nobody; fields absent from the response
            are None.
        """
        people_data = self.apollo_search_results(page, per_page, person_titles,
                                                 num_of_employees, person_location, organization_domains)
        logger.info(people_data)
        people = (people_data or {}).get('people') or []
        person_fields = ('first_name', 'last_name', 'name', 'linkedin_url', 'email', 'headline', 'title')
        return [{field: person.get(field) for field in person_fields} for person in people]

    def apollo_search_results(self, page, per_page, person_titles, num_of_employees=None,
                              person_location="", organization_domains=""):
        """
        Call the Apollo mixed_people/search endpoint.

        Args:
            page : The page of results to retrieve.
            per_page : The number of results to retrieve per page.
            person_titles : The titles of the people to search for.
            num_of_employees : [start_range, end_range]; a single value is treated
                as start == end. It is optional.
            person_location: Region country/state/city filter to search for. It is optional.
            organization_domains : The organization domains to search within. It is optional.

        Returns:
            The decoded JSON body on HTTP 200, otherwise None.
        """
        url = "https://api.apollo.io/v1/mixed_people/search"
        headers = {
            "Content-Type": "application/json",
            "Cache-Control": "no-cache"
        }
        data = {
            "api_key": self.get_tool_config("APOLLO_SEARCH_KEY"),
            "page": page,
            "per_page": per_page,
            "person_titles": person_titles,
            "contact_email_status": ["verified"]
        }
        if organization_domains:
            data["q_organization_domains"] = organization_domains
        if num_of_employees:
            start = num_of_employees[0]
            # A one-element list used to raise IndexError; treat it as start == end.
            end = num_of_employees[1] if len(num_of_employees) > 1 else start
            if end == start:
                data["num_of_employees"] = [str(start) + ","]
            else:
                data["num_of_employees"] = [str(start) + "," + str(end)]
        if person_location:
            data["person_locations"] = [person_location]
        response = requests.post(url, headers=headers, data=json.dumps(data))
        if response.status_code == 200:
            return response.json()
        logger.error(f"Apollo search failed with status code {response.status_code}")
        return None
================================================
FILE: superagi/tools/apollo/apollo_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.apollo.apollo_search import ApolloSearchTool
from superagi.tools.base_tool import BaseToolkit, BaseTool, ToolConfiguration
from superagi.types.key_type import ToolConfigKeyType
class ApolloToolkit(BaseToolkit, ABC):
    # Toolkit exposing the Apollo.io people-search tool; needs APOLLO_SEARCH_KEY.
    name: str = "ApolloToolkit"
    description: str = "Apollo Tool kit contains all tools related to apollo.io tasks"

    def get_tools(self) -> List[BaseTool]:
        """Return the Apollo tools."""
        return [ApolloSearchTool()]

    def get_env_keys(self) -> List[ToolConfiguration]:
        """Return the configuration keys the Apollo tools require."""
        return [ToolConfiguration(key="APOLLO_SEARCH_KEY", key_type=ToolConfigKeyType.STRING, is_required=True)]
================================================
FILE: superagi/tools/base_tool.py
================================================
from abc import abstractmethod
from functools import wraps
from inspect import signature
from typing import List
from typing import Optional, Type, Callable, Any, Union, Dict, Tuple
import yaml
from pydantic import BaseModel, create_model, validate_arguments, Extra
from superagi.models.tool_config import ToolConfig
from sqlalchemy import Column, Integer, String, Boolean
from superagi.types.key_type import ToolConfigKeyType
from superagi.config.config import get_config
class SchemaSettings:
    """Pydantic config for schemas inferred from function signatures.

    Unknown arguments are rejected, and arbitrary (non-pydantic) types are allowed.
    """
    arbitrary_types_allowed = True
    extra = Extra.forbid
def extract_valid_parameters(
        inferred_type: Type[BaseModel],
        function: Callable,
) -> dict:
    """
    Map each parameter of ``function`` to its JSON-schema property in ``inferred_type``.

    ``run_manager`` is excluded. A parameter with no matching property raises KeyError.
    """
    properties = inferred_type.schema()["properties"]
    selected = {}
    for param_name in signature(function).parameters:
        if param_name == "run_manager":
            continue
        selected[param_name] = properties[param_name]
    return selected
def _construct_model_subset(
        model_name: str, original_model: BaseModel, required_fields: list
) -> Type[BaseModel]:
    """Create a pydantic model holding only ``required_fields`` of ``original_model``.

    Each kept field reuses the original field's type and default. Names the
    original model does not define are ignored.
    """
    source_fields = original_model.__fields__
    subset = {}
    for field_name in required_fields:
        if field_name not in source_fields:
            continue
        model_field = source_fields[field_name]
        subset[field_name] = (model_field.type_, model_field.default)
    return create_model(model_name, **subset)  # type: ignore
def create_function_schema(
        schema_name: str,
        function: Callable,
) -> Type[BaseModel]:
    """Build a pydantic model named ``f"{schema_name}Schema"`` from ``function``'s signature.

    The model is inferred with pydantic's ``validate_arguments`` (configured by
    ``SchemaSettings``). ``run_manager`` is removed, and only parameters that
    appear in the function signature are kept.
    """
    inferred_model = validate_arguments(function, config=SchemaSettings).model  # type: ignore
    inferred_fields = inferred_model.__fields__
    if "run_manager" in inferred_fields:
        del inferred_fields["run_manager"]
    parameter_names = list(extract_valid_parameters(inferred_model, function))
    return _construct_model_subset(f"{schema_name}Schema", inferred_model, parameter_names)
class BaseToolkitConfiguration:
    """Default tool-configuration source that reads values from ``config.yaml``."""

    def __init__(self):
        # Database session; None here. Tools read it via ``self.toolkit_config.session``,
        # so presumably callers assign it later — TODO confirm.
        self.session = None

    def get_tool_config(self, key: str):
        """
        Return the value for ``key`` from ``config.yaml`` in the current working directory.

        Returns None when the key is absent, or when the file is empty or not a
        mapping. A missing file still raises ``FileNotFoundError``.
        """
        with open("config.yaml") as file:
            config = yaml.safe_load(file)
        # safe_load returns None for an empty document, which has no ``.get``.
        if not isinstance(config, dict):
            return None
        return config.get(key)
class BaseTool(BaseModel):
    # Base class for agent tools. Subclasses implement ``_execute``; when
    # ``args_schema`` is set it drives input validation and ``args``.
    name: str = None
    description: str
    args_schema: Type[BaseModel] = None
    permission_required: bool = True
    toolkit_config: BaseToolkitConfiguration = BaseToolkitConfiguration()

    class Config:
        arbitrary_types_allowed = True

    @property
    def args(self):
        """JSON-schema ``properties`` of the tool's arguments.

        Uses ``args_schema`` when set; otherwise infers a schema from ``execute``.
        """
        if self.args_schema is not None:
            return self.args_schema.schema()["properties"]
        args_schema = create_function_schema(f"{self.name}Schema", self.execute)
        return args_schema.schema()["properties"]

    @abstractmethod
    def _execute(self, *args: Any, **kwargs: Any):
        """Run the tool with already-parsed arguments; implemented by subclasses."""
        pass

    @property
    def max_token_limit(self):
        """Token limit from the ``MAX_TOOL_TOKEN_LIMIT`` config (default 600), as int."""
        return int(get_config("MAX_TOOL_TOKEN_LIMIT", 600))

    def _parse_input(
            self,
            tool_input: Union[str, Dict],
    ) -> Union[str, Dict[str, Any]]:
        """Validate tool input against ``args_schema`` (when set).

        A string input is validated as the value of the schema's first field
        and returned unchanged. A dict input is parsed, and only the keys the
        caller supplied are returned, with their validated values.
        """
        input_args = self.args_schema
        if isinstance(tool_input, str):
            if input_args is not None:
                key_ = next(iter(input_args.__fields__.keys()))
                input_args.validate({key_: tool_input})
            return tool_input
        if input_args is not None:
            result = input_args.parse_obj(tool_input)
            return {k: v for k, v in result.dict().items() if k in tool_input}
        return tool_input

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        """Split parsed input into positional and keyword arguments for ``_execute``."""
        # For backwards compatibility, if run_input is a string,
        # pass as a positional argument.
        if isinstance(tool_input, str):
            return (tool_input,), {}
        return (), tool_input

    def execute(
            self,
            tool_input: Union[str, Dict],
            **kwargs: Any
    ) -> Any:
        """Run the tool.

        Args:
            tool_input: Raw string or dict of arguments; validated by ``_parse_input``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Whatever ``_execute`` returns. Exceptions from ``_execute`` propagate unchanged.
        """
        parsed_input = self._parse_input(tool_input)
        tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
        return self._execute(*tool_args, **tool_kwargs)

    @classmethod
    def from_function(cls, func: Callable, args_schema: Type[BaseModel] = None):
        """Create a tool described by ``func``'s docstring, optionally with ``args_schema``."""
        if args_schema:
            return cls(description=func.__doc__, args_schema=args_schema)
        return cls(description=func.__doc__)

    def get_tool_config(self, key):
        """Look up ``key`` through this tool's ``toolkit_config``."""
        return self.toolkit_config.get_tool_config(key=key)
class FunctionalTool(BaseTool):
    # Tool that wraps a plain Python function (built by the ``tool`` decorator).
    name: str = None
    description: str
    func: Callable
    args_schema: Type[BaseModel] = None

    @property
    def args(self):
        """JSON-schema properties of the wrapped function's arguments.

        Uses ``args_schema`` when set. Otherwise the schema is inferred from
        ``func`` itself; inferring from ``execute`` would only describe ``tool_input``.
        """
        if self.args_schema is not None:
            return self.args_schema.schema()["properties"]
        args_schema = create_function_schema(f"{self.name}Schema", self.func)
        return args_schema.schema()["properties"]

    def _execute(self, *args: Any, **kwargs: Any):
        """Call the wrapped function, forwarding positional AND keyword arguments."""
        # Keyword arguments must be unpacked; passing the dict positionally
        # shifted every argument of the wrapped function.
        return self.func(*args, **kwargs)

    @classmethod
    def from_function(cls, func: Callable, args_schema: Type[BaseModel] = None):
        """Wrap ``func`` in a FunctionalTool.

        The tool is named after the function and described by its docstring
        (empty when the function has none). ``func`` is a required field and
        must be passed for validation to succeed.
        """
        return cls(name=func.__name__, description=func.__doc__ or "", func=func, args_schema=args_schema)
def registerTool(cls):
    """Class decorator: mark ``cls`` with ``__registerTool__ = True`` and return it unchanged."""
    setattr(cls, "__registerTool__", True)
    return cls
def tool(*args: Union[str, Callable], return_direct: bool = False,
         args_schema: Optional[Type[BaseModel]] = None) -> Callable:
    """
    Decorator turning a function into a ``FunctionalTool``.

    Usable bare (``@tool``) or called (``@tool(args_schema=...)``). A leading
    non-callable positional argument is accepted but ignored.

    The decorated name becomes a wrapper. With ``return_direct`` the wrapper
    runs the tool's ``_execute`` with the call's arguments; otherwise calling
    it returns the ``FunctionalTool`` instance.
    """
    def decorator(func: Callable) -> Callable:
        tool_instance = FunctionalTool.from_function(func, args_schema)

        @wraps(func)
        def wrapper(*tool_args, **tool_kwargs):
            if return_direct:
                # Tools implement ``_execute``; there is no ``_exec`` method.
                return tool_instance._execute(*tool_args, **tool_kwargs)
            return tool_instance

        return wrapper

    if len(args) == 1 and callable(args[0]):
        return decorator(args[0])
    return decorator
class ToolConfiguration:
    """
    Describes one configuration key required by a toolkit.

    Args:
        key: Configuration key name.
        key_type: A ``ToolConfigKeyType`` member; None means ``ToolConfigKeyType.STRING``.
        is_required: Whether the key must be provided; None is treated as False.
        is_secret: Whether the value is secret; None is treated as False.

    Raises:
        ValueError: if a flag is neither bool nor None, or ``key_type`` is
            neither a ``ToolConfigKeyType`` nor None. Flags are checked first,
            in the order is_secret, then is_required.
    """

    def __init__(self, key: str, key_type: str = None, is_required: bool = False, is_secret: bool = False):
        self.key = key
        self.is_secret = self._as_flag(is_secret, "is_secret should be a boolean value")
        self.is_required = self._as_flag(is_required, "is_required should be a boolean value")
        if key_type is None:
            key_type = ToolConfigKeyType.STRING
        elif not isinstance(key_type, ToolConfigKeyType):
            raise ValueError("key_type should be string/file/integer")
        self.key_type = key_type

    @staticmethod
    def _as_flag(value, error_message):
        """Normalise an optional boolean: None -> False, bool -> itself, otherwise ValueError."""
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        raise ValueError(error_message)
class BaseToolkit(BaseModel):
    # Base class for toolkits: a named group of tools plus the configuration
    # keys (ToolConfiguration objects) those tools need.
    name: str
    description: str

    @abstractmethod
    def get_tools(self) -> List[BaseTool]:
        """Return the tool instances provided by this toolkit."""
        # Add file related tools object here
        pass

    @abstractmethod
    def get_env_keys(self) -> List[ToolConfiguration]:
        """Return the configuration keys this toolkit's tools require."""
        # Add file related config keys here
        pass
================================================
FILE: superagi/tools/code/README.MD
================================================
# SuperAGI Coding Tool
The robust SuperAGI Coding Tool helps developers with their coding tasks, such as writing, reviewing, and refactoring code, fixing bugs, and understanding programming concepts.
## 💡 Features
1. **Write Code:** With SuperAGI's Coding Tool, writing new code is a streamlined and effortless process, making your programming tasks much simpler.
2. **Review Code:** SuperAGI's Coding Tool allows comprehensive code reviews, ensuring your code maintains quality standards and adheres to best practices.
3. **Refactor Code:** Refactoring your code is a breeze with SuperAGI's Coding Tool, allowing you to improve your code structure without changing its functionality.
4. **Debugging:** The Coding Tool is equipped to identify and fix bugs efficiently, ensuring your code performs as intended.
5. **Concept Explanation:** This feature provides clear explanations for various programming concepts, enhancing your understanding and making complex coding problems easier to solve.
## ⚙️ Installation
### 🛠 **Setting Up SuperAGI**
Set up SuperAGI by following the instructions in the [README](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
You'll be able to use the Coding Tool on the fly once you have set up SuperAGI.
## Running SuperAGI Coding Tool
You can simply ask your agent to read or go through your code files in the Resource Manager, and it will be able to perform any of the coding tasks listed above.
================================================
FILE: superagi/tools/code/__init__.py
================================================
================================================
FILE: superagi/tools/code/coding_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseToolkit, BaseTool, ToolConfiguration
from superagi.tools.code.improve_code import ImproveCodeTool
from superagi.tools.code.write_code import CodingTool
from superagi.tools.code.write_spec import WriteSpecTool
from superagi.tools.code.write_test import WriteTestTool
class CodingToolkit(BaseToolkit, ABC):
    # Groups the code tools: write code, write spec, write tests, improve code.
    name: str = "CodingToolkit"
    description: str = "Coding Tool kit contains all tools related to coding tasks"

    def get_tools(self) -> List[BaseTool]:
        """Return fresh instances of the coding tools."""
        coding_tools: List[BaseTool] = [CodingTool(), WriteSpecTool(), WriteTestTool(), ImproveCodeTool()]
        return coding_tools

    def get_env_keys(self) -> List[ToolConfiguration]:
        """The coding tools need no configuration keys."""
        return list()
================================================
FILE: superagi/tools/code/improve_code.py
================================================
import re
from typing import Type, Optional, List
from pydantic import BaseModel, Field
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
class ImproveCodeSchema(BaseModel):
    # ImproveCodeTool takes no arguments; it reads its input from the agent's files.
    ...
class ImproveCodeTool(BaseTool):
    """
    Used to improve the already generated code by reading the code from the files

    Attributes:
        llm: LLM used for code generation.
        name : The name of the tool.
        description : The description of the tool.
        resource_manager: Manages the file resources.
    """
    llm: Optional[BaseLlm] = None
    agent_id: int = None
    agent_execution_id: int = None
    name = "ImproveCodeTool"
    description = (
        "This tool improves the generated code."
    )
    args_schema: Type[ImproveCodeSchema] = ImproveCodeSchema
    resource_manager: Optional[FileManager] = None
    tool_response_manager: Optional[ToolResponseQueryManager] = None
    goals: List[str] = []

    class Config:
        arbitrary_types_allowed = True

    @staticmethod
    def _extract_completion_content(result):
        """Return the completion text from a ``chat_completion`` result, or None.

        Prefers the top-level ``content`` key (the key CodingTool reads), then
        falls back to ``response['choices'][0]['message']['content']``.
        """
        content = result.get('content')
        if content:
            return content
        try:
            return result['response']['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            return None

    def _execute(self) -> str:
        """
        Execute the improve code tool.

        Every file returned by the resource manager, except .txt/.sh/.json
        files, is sent to the LLM, and the fenced code in the reply replaces
        the file's content. A file is left untouched when the reply contains
        no fenced code block.

        Returns:
            A success message listing the files, or an error string when the
            LLM call fails or returns nothing, or when a file cannot be saved.
        """
        # Get all file names that the CodingTool has written
        file_names = self.resource_manager.get_files()
        logger.info(file_names)
        # The template and goals are the same for every file; prepare them once.
        prompt_template = PromptReader.read_tools_prompt(__file__, "improve_code.txt")
        prompt_template = prompt_template.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
        for file_name in file_names:
            # Suffix match, so names such as "shape.py" are not mistaken for ".sh".
            if file_name.endswith(('.txt', '.sh', '.json')):
                continue
            content = self.resource_manager.read_file(file_name)
            prompt = prompt_template.replace("{content}", content)
            # NOTE(review): the content is already substituted for {content};
            # appending it again doubles prompt tokens — confirm this is intended.
            prompt = prompt + "\nOriginal Code:\n```\n" + content + "\n```"
            result = self.llm.chat_completion([{'role': 'system', 'content': prompt}])
            if result is None:
                return f"Error: no response from the LLM while improving {file_name}"
            if 'error' in result and result.get('message') is not None:
                ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id,
                                                  self.agent_execution_id, result['message'])
                return f"Error: {result['message']}"
            improved_content = self._extract_completion_content(result)
            if not improved_content:
                logger.info("RESPONSE NOT AVAILABLE")
                return f"Error: empty LLM response while improving {file_name}"
            parsed_content = re.findall(r"```(?:\w*\n)?(.*?)```", improved_content, re.DOTALL)
            if not parsed_content:
                # Writing "" here would wipe the file; keep the original instead.
                logger.info(f"No code block in LLM response for {file_name}; leaving it unchanged")
                continue
            save_result = self.resource_manager.write_file(file_name, "\n".join(parsed_content))
            if save_result.startswith("Error"):
                return save_result
        return f"All codes improved and saved successfully in: " + " ".join(file_names)
================================================
FILE: superagi/tools/code/prompts/generate_logic.txt
================================================
You typically always place distinct classes in separate files.
Always create a run.sh file that acts as the entrypoint of the program; create it intelligently after analyzing the file types.
For Python, always generate a suitable requirements.txt file.
For NodeJS, consistently produce an appropriate package.json file.
Always include a brief comment that describes the purpose of the function definition.
Attempt to provide comments that explain complicated logic.
Consistently adhere to best practices for the specified languages, ensuring code is defined as a package or project.
Preferred Python toolbelt:
- pytest
- dataclasses
================================================
FILE: superagi/tools/code/prompts/improve_code.txt
================================================
You are a super smart developer. You have been tasked with fixing and filling in the functions and classes where only a description of the code is written, without the actual code. There might be placeholders in the code that you have to fill in.
You provide fully functioning, well formatted code with few comments, that works and has no bugs.
If the code is already correct and doesn't need change, just return the same code
However, make sure that you only return the improved code, without any additional content.
Please structure the improved code as follows:
```
CODE
```
Please return the full new code in same format as the original code
Don't write any explanation or description in your response other than the actual code
Your high-level goal is:
{goals}
The content of the file you need to improve is:
{content}
Only return the code and not any other line
To start, first analyze the existing code. Check for any function with missing logic inside it and fill the function.
Make sure that not a single function is empty or contains just comments; there should be function logic inside each one.
Return fully completed functions by filling the placeholders
================================================
FILE: superagi/tools/code/prompts/write_code.txt
================================================
You are a super smart developer who follows good development practices when writing code according to a specification.
Please note that the code should be fully functional. There should be no placeholder in functions or classes in any file.
Your high-level goal is:
{goals}
Coding task description:
{code_description}
{spec}
You will get instructions for code to write.
You need to write a detailed answer. Make sure all parts of the architecture are turned into code.
Think carefully about each step and make good choices to get it right. First, list the main classes,
functions, methods you'll use and a quick comment on their purpose.
Then you will output the content of each file including ALL code.
Each file must strictly follow a markdown code block format, where the following tokens must be replaced such that
FILENAME is the lowercase file name including the file extension,
[LANG] is the markup code block language for the code's language, and [CODE] is the code:
FILENAME
```[LANG]
[CODE]
```
You will start with the "entrypoint" file, then go to the ones that are imported by that file, and so on.
Follow a language and framework appropriate best practice file naming convention.
Make sure that files contain all imports, types etc. Make sure that code in different files are compatible with each other.
Ensure to implement all code, if you are unsure, write a plausible implementation.
Include module dependency or package manager dependency definition file.
Before you finish, double-check that all parts of the architecture are present in the files.
================================================
FILE: superagi/tools/code/prompts/write_spec.txt
================================================
You are a super smart developer who has been asked to make a specification for a program.
Your high-level goal is:
{goals}
Please keep in mind the following when creating the specification:
1. Be super explicit about what the program should do, which features it should have, and give details about anything that might be unclear.
2. Lay out the names of the core classes, functions, methods that will be necessary, as well as a quick comment on their purpose.
3. List all non-standard dependencies that will have to be used.
Write a specification for the following task:
{task}
================================================
FILE: superagi/tools/code/prompts/write_test.txt
================================================
You are a super smart developer who practices Test Driven Development for writing tests according to a specification.
Your high-level goal is:
{goals}
Test Description:
{test_description}
{spec}
Test should follow the following format:
FILENAME is the lowercase file name including the file extension,
[LANG] is the markup code block language for the code's language, and [UNIT_TEST_CODE] is the code:
FILENAME
```[LANG]
[UNIT_TEST_CODE]
```
The tests should be as simple as possible, but still cover all the functionality described in the specification.
================================================
FILE: superagi/tools/code/write_code.py
================================================
import re
from typing import Type, Optional, List
from pydantic import BaseModel, Field
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
from superagi.models.agent import Agent
class CodingSchema(BaseModel):
    # The single, required argument of CodingTool.
    code_description: str = Field(
        default=...,
        description="Description of the coding task",
    )
class CodingTool(BaseTool):
    """
    Used to generate code.

    Attributes:
        llm: LLM used for code generation.
        name : The name of tool.
        description : The description of tool.
        args_schema : The args schema.
        goals : The goals.
        resource_manager: Manages the file resources
        tool_response_manager: Looks up earlier tool outputs (the last WriteSpecTool response is used as the spec).
    """
    llm: Optional[BaseLlm] = None
    agent_id: int = None
    agent_execution_id: int = None
    name = "CodingTool"
    description = (
        "You will get instructions for code to write. You will write a very long answer. "
        "Make sure that every detail of the architecture is, in the end, implemented as code. "
        "Think step by step and reason yourself to the right decisions to make sure we get it right. "
        "You will first lay out the names of the core classes, functions, methods that will be necessary, "
        "as well as a quick comment on their purpose. Then you will output the content of each file including each function and class and ALL code."
    )
    args_schema: Type[CodingSchema] = CodingSchema
    goals: List[str] = []
    resource_manager: Optional[FileManager] = None
    tool_response_manager: Optional[ToolResponseQueryManager] = None

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, code_description: str) -> str:
        """
        Execute the write_code tool.

        Args:
            code_description : The coding task description.

        Returns:
            The generated code followed by the list of files it was saved to,
            or an "Error..." message (LLM failure or a failed file write).
        """
        prompt = PromptReader.read_tools_prompt(__file__, "write_code.txt") + "\nUseful to know:\n" + PromptReader.read_tools_prompt(__file__, "generate_logic.txt")
        prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
        prompt = prompt.replace("{code_description}", code_description)
        spec_response = self.tool_response_manager.get_last_response("WriteSpecTool")
        if spec_response:
            prompt = prompt.replace("{spec}", "Use this specs for generating the code:\n" + spec_response)
        else:
            # No spec available: drop the placeholder rather than sending a literal "{spec}" to the LLM.
            prompt = prompt.replace("{spec}", "")
        logger.info(prompt)
        messages = [{"role": "system", "content": prompt}]

        organisation = Agent.find_org_by_agent_id(session=self.toolkit_config.session, agent_id=self.agent_id)
        total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
        token_limit = TokenCounter(session=self.toolkit_config.session, organisation_id=organisation.id).token_limit(self.llm.get_model())
        # Leave the tokens the prompt does not use (minus a 100-token margin) for the reply.
        result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
        if 'error' in result and result['message'] is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
            # There is nothing to parse from an error response; stop here instead of
            # failing on result["content"] below.
            return "Error: " + str(result['message'])

        # Markdown decoration that may wrap a file name, e.g. `main.py`, [main.py] or main.py:
        wrapper_chars = "`'[]():"
        # Each generated file is expected as "<file name>\n```<lang>\n<code>```".
        regex = r"(\S+?)\n```\S*\n(.+?)```"
        file_names = []
        for match in re.finditer(regex, result["content"], re.DOTALL):
            # Remove characters that are illegal in file names, then unwrap decoration.
            # Only wrapper characters are stripped, so names like ".gitignore" or
            # "__init__.py" survive intact.
            file_name = re.sub(r'[<>"|?*]', "", match.group(1)).strip(wrapper_chars)
            # Check for emptiness before using the name.
            if not file_name.strip():
                continue
            code = match.group(2)
            file_names.append(file_name)
            save_result = self.resource_manager.write_file(file_name, code)
            if save_result.startswith("Error"):
                return save_result

        # Everything before the first code fence is the LLM's overview; keep it as the README.
        readme = result["content"].split("```")[0]
        save_readme_result = self.resource_manager.write_file("README.md", readme)
        if save_readme_result.startswith("Error"):
            return save_readme_result
        return result["content"] + "\n Codes generated and saved successfully in " + ", ".join(file_names)
================================================
FILE: superagi/tools/code/write_spec.py
================================================
from typing import Type, Optional, List
from pydantic import BaseModel, Field
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.models.agent import Agent
# Argument schema for WriteSpecTool: what to specify, and the bare file name
# under which the specification is stored.
class WriteSpecSchema(BaseModel):
    task_description: str = Field(default=..., description="Specification task description.")
    spec_file_name: str = Field(
        default=...,
        description="Name of the file to write. Only include the file name. Don't include path.",
    )
class WriteSpecTool(BaseTool):
    """
    Used to generate program specification.

    Attributes:
        llm: LLM used for specification generation.
        name : The name of tool.
        description : The description of tool.
        args_schema : The args schema.
        goals : The goals.
        resource_manager: Manages the file resources
    """
    llm: Optional[BaseLlm] = None
    agent_id: int = None
    agent_execution_id: int = None
    name = "WriteSpecTool"
    description = (
        "A tool to write the spec of a program."
    )
    args_schema: Type[WriteSpecSchema] = WriteSpecSchema
    goals: List[str] = []
    resource_manager: Optional[FileManager] = None

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, task_description: str, spec_file_name: str) -> str:
        """
        Execute the write_spec tool.

        Args:
            task_description : The task description.
            spec_file_name: The name of the file where the generated specification will be saved.

        Returns:
            Generated specification or error message.
        """
        prompt = PromptReader.read_tools_prompt(__file__, "write_spec.txt")
        prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
        prompt = prompt.replace("{task}", task_description)
        messages = [{"role": "system", "content": prompt}]

        organisation = Agent.find_org_by_agent_id(self.toolkit_config.session, agent_id=self.agent_id)
        total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
        token_limit = TokenCounter(session=self.toolkit_config.session, organisation_id=organisation.id).token_limit(self.llm.get_model())
        # Leave the tokens the prompt does not use (minus a 100-token margin) for the reply.
        result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
        if 'error' in result and result['message'] is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
            # An error response has nothing to save; avoid failing on result["content"].
            return "Error: " + str(result['message'])

        # Save the specification to a file
        write_result = self.resource_manager.write_file(spec_file_name, result["content"])
        if write_result.startswith("Error"):
            return write_result
        return result["content"] + "\nSpecification generated and saved successfully"
================================================
FILE: superagi/tools/code/write_test.py
================================================
import re
from typing import Type, Optional, List
from pydantic import BaseModel, Field
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
from superagi.models.agent import Agent
# Argument schema for WriteTestTool: what the tests should cover, and the bare
# file name the tests belong in.
class WriteTestSchema(BaseModel):
    test_description: str = Field(default=..., description="Description of the testing task")
    test_file_name: str = Field(
        default=...,
        description="Name of the file to write. Only include the file name. Don't include path.",
    )
class WriteTestTool(BaseTool):
    """
    Used to generate unit tests based on the specification.

    Attributes:
        llm: LLM used for test generation.
        name : The name of tool.
        description : The description of tool.
        args_schema : The args schema.
        goals : The goals.
        resource_manager: Manages the file resources
        tool_response_manager: Looks up earlier tool outputs used as the specification.
    """
    llm: Optional[BaseLlm] = None
    agent_id: int = None
    agent_execution_id: int = None
    name = "WriteTestTool"
    description = (
        "You are a super smart developer using Test Driven Development to write tests according to a specification.\n"
        "Please generate tests based on the above specification. The tests should be as simple as possible, "
        "but still cover all the functionality.\n"
        "Write it in the file"
    )
    args_schema: Type[WriteTestSchema] = WriteTestSchema
    goals: List[str] = []
    resource_manager: Optional[FileManager] = None
    tool_response_manager: Optional[ToolResponseQueryManager] = None

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, test_description: str, test_file_name: str) -> str:
        """
        Execute the write_test tool.

        Args:
            test_description : The specification description.
            test_file_name: Fallback file for the tests, used when the response contains
                no "<file name>\\n```...```" blocks.

        Returns:
            Generated unit tests plus the files they were saved to, or an error message.
        """
        prompt = PromptReader.read_tools_prompt(__file__, "write_test.txt")
        prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
        prompt = prompt.replace("{test_description}", test_description)
        # Prefer the last WriteSpecTool output; otherwise use whatever tool ran last.
        spec_response = self.tool_response_manager.get_last_response("WriteSpecTool")
        if not spec_response:
            spec_response = self.tool_response_manager.get_last_response()
        if spec_response:
            prompt = prompt.replace("{spec}",
                                    "Please generate unit tests based on the following specification description:\n" + spec_response)
        else:
            # No spec at all: don't send a literal "{spec}" placeholder to the LLM.
            prompt = prompt.replace("{spec}", "")
        messages = [{"role": "system", "content": prompt}]
        logger.info(prompt)

        organisation = Agent.find_org_by_agent_id(self.toolkit_config.session, agent_id=self.agent_id)
        total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
        token_limit = TokenCounter(session=self.toolkit_config.session, organisation_id=organisation.id).token_limit(self.llm.get_model())
        # Leave the tokens the prompt does not use (minus a 100-token margin) for the reply.
        result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
        if 'error' in result and result['message'] is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
            # An error response has nothing to save; avoid failing on result["content"].
            return "Error: " + str(result['message'])

        # Each generated file is expected as "<file name>\n```<lang>\n<code>```".
        regex = r"(\S+?)\n```\S*\n(.+?)```"
        file_names = []
        for match in re.finditer(regex, result["content"], re.DOTALL):
            file_name = re.sub(r'[<>"|?*]', "", match.group(1))
            code = match.group(2)
            if not file_name.strip():
                continue
            file_names.append(file_name)
            save_result = self.resource_manager.write_file(file_name, code)
            if save_result.startswith("Error"):
                return save_result

        if not file_names:
            # No per-file blocks were found, so keep the whole response in the requested file
            # instead of reporting a save that never happened.
            save_result = self.resource_manager.write_file(test_file_name, result["content"])
            if save_result.startswith("Error"):
                return save_result
            file_names.append(test_file_name)
        return result["content"] + " \n Tests generated and saved successfully in " + ", ".join(file_names)
================================================
FILE: superagi/tools/duck_duck_go/README.md
================================================
# SuperAGI DuckDuckGo Search Tool
The SuperAGI DuckDuckGo Search Tool helps users perform a DuckDuckGo search and extract snippets and webpages.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions in the main README (https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
## Running SuperAGI DuckDuckGo Search Tool
Ask your agent for the latest information on any topic, and it will search the internet and retrieve that information for you.
================================================
FILE: superagi/tools/duck_duck_go/__init__.py
================================================
================================================
FILE: superagi/tools/duck_duck_go/duck_duck_go_search.py
================================================
import json
import requests
from typing import Type, Optional,Union
import time
from superagi.helper.error_handler import ErrorHandler
from superagi.lib.logger import logger
from pydantic import BaseModel, Field
from duckduckgo_search import DDGS
from itertools import islice
from superagi.helper.token_counter import TokenCounter
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.tools.base_tool import BaseTool
from superagi.helper.webpage_extractor import WebpageExtractor
# Tuning constants for DuckDuckGoSearchTool.
DUCKDUCKGO_MAX_ATTEMPTS: int = 3          # searches tried before giving up on an empty result
WEBPAGE_EXTRACTOR_MAX_ATTEMPTS: int = 2   # extra fetches of a page that came back empty
MAX_LINKS_TO_SCRAPE: int = 3              # how many result links are fetched and read
NUM_RESULTS_TO_USE: int = 10              # how many raw search results are kept
# Argument schema for DuckDuckGoSearchTool: just the query string.
class DuckDuckGoSearchSchema(BaseModel):
    query: str = Field(default=..., description="The search query for duckduckgo search.")
class DuckDuckGoSearchTool(BaseTool):
    """
    Duck Duck Go Search tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    llm: Optional[BaseLlm] = None
    name = "DuckDuckGoSearch"
    agent_id: int = None
    agent_execution_id: int = None
    description = (
        "A tool for performing a DuckDuckGo search and extracting snippets and webpages."
        "Input should be a search query."
    )
    args_schema: Type[DuckDuckGoSearchSchema] = DuckDuckGoSearchSchema

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, query: str) -> str:
        """
        Execute the DuckDuckGo search tool.

        Args:
            query : The query to search for.

        Returns:
            Search result summary along with up to three related links.
        """
        search_results = self.get_raw_duckduckgo_results(query)
        links = [result["href"] for result in search_results]
        webpages = self.get_content_from_url(links)
        # Objects with keys: {"title": snippet title, "body": webpage content, "links": link URL}
        results = self.get_formatted_webpages(search_results, webpages)
        summary = self.summarise_result(query, results)
        links = [result["links"] for result in results if len(result["links"]) > 0]
        if len(links) > 0:
            return summary + "\n\nLinks:\n" + "\n".join("- " + link for link in links[:3])
        return summary

    def get_formatted_webpages(self, search_results, webpages):
        """
        Pair each scraped webpage with its search result for the summariser (summarise_result).

        Args:
            search_results : The result objects fetched by DuckDuckGo (need "title" and "href").
            webpages : Page contents, in the same order as the first search results.

        Returns:
            A list of {"title", "body", "links"} dicts. Collection stops once the list
            serialises to more than 3000 tokens; the entry that crossed the limit is kept.
        """
        results = []
        for search_result, webpage in zip(search_results, webpages):
            results.append({"title": search_result["title"], "body": webpage, "links": search_result["href"]})
            if TokenCounter.count_text_tokens(json.dumps(results)) > 3000:
                break
        return results

    def get_content_from_url(self, links):
        """
        Fetch the text of the first MAX_LINKS_TO_SCRAPE links.

        Args:
            links : The URLs fetched by DuckDuckGo (may hold fewer than MAX_LINKS_TO_SCRAPE).

        Returns:
            One content string per fetched link, cut to roughly its first 500 space-separated words.
        """
        webpages = []
        for link in (links or [])[:MAX_LINKS_TO_SCRAPE]:
            # Pause before each request, as the original implementation does.
            time.sleep(3)
            content = WebpageExtractor().extract_with_bs4(link)
            attempts = 0
            while content == "" and attempts < WEBPAGE_EXTRACTOR_MAX_ATTEMPTS:
                attempts += 1
                content = WebpageExtractor().extract_with_bs4(link)
            # Truncate only after retrying. Measuring the cut-off on an empty first
            # attempt would truncate every retry to "".
            max_length = len(' '.join(content.split(" ")[:500]))
            webpages.append(content[:max_length])
        return webpages

    def get_raw_duckduckgo_results(self, query):
        """
        Get raw search results from the duckduckgo_search package.

        Args:
            query : The query to search for.

        Returns:
            A list of at most NUM_RESULTS_TO_USE raw results. The list is empty for an
            empty query, or when every attempt returns nothing.
        """
        if not query:
            # Always return a list; callers iterate over the results and index them by key.
            return []
        search_results = []
        for _ in range(DUCKDUCKGO_MAX_ATTEMPTS):
            results = DDGS().text(query) or []
            search_results = list(islice(results, NUM_RESULTS_TO_USE))
            if search_results:
                break
        return search_results

    def summarise_result(self, query, snippets):
        """
        Summarise the result of a DuckDuckGo search.

        Args:
            query : The query to search for.
            snippets (list): A list of snippets from the search.

        Returns:
            A summary of the search result, or an "Error: ..." message if the LLM call failed.
        """
        summarize_prompt ="""Summarize the following text `{snippets}`
            Write a concise or as descriptive as necessary and attempt to
            answer the query: `{query}` as best as possible. Use markdown formatting for
            longer responses."""
        summarize_prompt = summarize_prompt.replace("{snippets}", str(snippets))
        summarize_prompt = summarize_prompt.replace("{query}", query)
        messages = [{"role": "system", "content": summarize_prompt}]
        result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
        if 'error' in result and result['message'] is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
            # An error response has no "content" to return.
            return "Error: " + str(result['message'])
        return result["content"]
================================================
FILE: superagi/tools/duck_duck_go/duck_duck_go_search_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.duck_duck_go.duck_duck_go_search import DuckDuckGoSearchTool
from superagi.types.key_type import ToolConfigKeyType
from superagi.models.tool_config import ToolConfig
# Toolkit that exposes the DuckDuckGo search tool. It declares no configuration keys.
class DuckDuckGoToolkit(BaseToolkit, ABC):
    name: str = "DuckDuckGo Search Toolkit"
    description: str = "Toolkit containing tools for performing DuckDuckGo search and extracting snippets and webpages"

    def get_tools(self) -> List[BaseTool]:
        tools: List[BaseTool] = [DuckDuckGoSearchTool()]
        return tools

    def get_env_keys(self) -> List[ToolConfiguration]:
        # Add more config keys specific to your project
        return list()
================================================
FILE: superagi/tools/email/README.md
================================================
# SuperAGI Email Tool
The robust SuperAGI Email Tool lets users send and read emails while providing a foundation for other fascinating use cases.
## 💡 Features
1. **Read Emails:** With SuperAGI's Email Tool, you can effortlessly manage your inbox and ensure that you never overlook a critical detail.
2. **Send Emails:** SuperAGI's Email Tool uses its comprehensive language model capabilities to create personalised, context-aware emails, sparing you effort and time.
3. **Save Emails to Drafts Folder:** Let SuperAGI save email drafts that you can review and modify before sending. This gives you greater control and ensures your messages are tailored to your preferences.
4. **Send Emails with Attachments:** Send attachments in emails with ease to enrich and expand the scope of your message.
5. **Custom Email Signature:** Create a unique signature for each email you send to add a touch of customization and automation.
6. **Auto-Reply and Answer Questions:** Allow SuperAGI to read, analyse, and respond to incoming emails with precise answers to streamline your email responses.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions in the main README (https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
### 🔧 **Add Email configuration settings in SuperAGI's Dashboard**

Add the following configuration in the Email Toolkit Page:
1. _Email address and password:_
- Set 'EMAIL_ADDRESS' to sender's email address
- Set 'EMAIL_PASSWORD' to your password. If you use Gmail, use an App Password (follow the steps below to obtain one).
2. _Provider-specific settings:_
- If you are not using Gmail, set 'EMAIL_SMTP_HOST', 'EMAIL_SMTP_PORT', and 'EMAIL_IMAP_SERVER' to match your email service provider.
3. _Sending and Drafts:_
- Set EMAIL_DRAFT_MODE to "FALSE" if you want your emails sent directly, or to "TRUE" if you want them saved as drafts.
- If you set Draft Mode to "TRUE", set 'EMAIL_DRAFT_FOLDER' to your email provider's drafts folder so that emails are saved there instead of being sent.
4. _Optional Settings:_
- Set 'EMAIL_SIGNATURE' to your personalized signature.
## Obtain your App Password
To obtain App password for your Gmail Account follow the steps:
- Navigate to the link (https://myaccount.google.com/apppasswords)

- To get the App Password ensure that you have set up 2-Step Verification for your email address.
- Generate the password by creating a custom app

- Copy the password generated and use it for 'EMAIL_PASSWORD'
- Also make sure IMAP Access is enabled for your Gmail Address (Settings > See all settings > Forwarding and POP/IMAP > Enable IMAP)

## Running SuperAGI Email Tool
1. **Read an email**
By default, SuperAGI's email tool reads the last 5 emails from your inbox. To change this limit, modify the default limit in read_email.py.
2. **Send an email**
To send an email to a particular receiver, mention the receiver's email address in your goal. If no receiver address is mentioned, the email is not sent; in draft mode, it is saved to the drafts folder instead.

3. **Send an email with attachment**
SuperAGI can send Emails with Attachments if you have uploaded the file in the Resource Manager, or if your file is in the Input or the Output of your SuperAGI Workspace.

```
================================================
FILE: superagi/tools/email/__init__.py
================================================
================================================
FILE: superagi/tools/email/email_toolkit.py
================================================
from abc import ABC
from superagi.tools.base_tool import BaseToolkit, BaseTool, ToolConfiguration
from typing import Type, List
from superagi.tools.email.read_email import ReadEmailTool
from superagi.tools.email.send_email import SendEmailTool
from superagi.tools.email.send_email_attachment import SendEmailAttachmentTool
from superagi.types.key_type import ToolConfigKeyType
# Toolkit bundling the read, send, and send-with-attachment email tools.
class EmailToolkit(BaseToolkit, ABC):
    name: str = "Email Toolkit"
    description: str = "Email Tool kit contains all tools related to sending email"

    def get_tools(self) -> List[BaseTool]:
        return [ReadEmailTool(), SendEmailTool(), SendEmailAttachmentTool()]

    def get_env_keys(self) -> List[ToolConfiguration]:
        # (key, is_required, is_secret) for each configuration key; all keys are strings.
        key_specs = (
            ("EMAIL_ADDRESS", True, False),
            ("EMAIL_PASSWORD", True, True),
            ("EMAIL_SIGNATURE", False, False),
            ("EMAIL_DRAFT_MODE", True, False),
            ("EMAIL_DRAFT_FOLDER", True, False),
            ("EMAIL_SMTP_HOST", True, False),
            ("EMAIL_SMTP_PORT", True, False),
            ("EMAIL_IMAP_SERVER", True, False),
        )
        return [
            ToolConfiguration(key=key, key_type=ToolConfigKeyType.STRING, is_required=required, is_secret=secret)
            for key, required, secret in key_specs
        ]
================================================
FILE: superagi/tools/email/read_email.py
================================================
import email
import json
from typing import Type
from pydantic import BaseModel, Field
from superagi.helper.imap_email import ImapEmail
from superagi.helper.read_email import ReadEmail
from superagi.helper.token_counter import TokenCounter
from superagi.tools.base_tool import BaseTool
# Arguments accepted by ReadEmailTool. The defaults match the ones declared on
# ReadEmailTool._execute and the values each description promises. Previously
# the fields were required, which contradicted "default value is"/"Defaults to".
class ReadEmailInput(BaseModel):
    imap_folder: str = Field("INBOX", description="Email folder to read from. default value is \"INBOX\"")
    page: int = Field(0,
                      description="The index of the page result the function should return. Defaults to 0, the first page.")
    limit: int = Field(5, description="Number of emails to fetch in one cycle. Defaults to 5.")
class ReadEmailTool(BaseTool):
    """
    Read emails from an IMAP mailbox

    Attributes:
        name : The name of the tool.
        description : The description of the tool.
        args_schema : The args schema.
    """
    name: str = "Read Email"
    args_schema: Type[BaseModel] = ReadEmailInput
    description: str = "Read emails from an IMAP mailbox"

    def _execute(self, imap_folder: str = "INBOX", page: int = 0, limit: int = 5) -> str:
        """
        Execute the read email tool.

        Args:
            imap_folder : The email folder to read from. Defaults to "INBOX".
            page : The index of the page result the function should return. Defaults to 0,
                the first page (the newest ``limit`` messages).
            limit : Number of emails to fetch in one cycle. Defaults to 5.

        Returns:
            A list of email dicts (newest first), a "no Email" message, or an error message.
        """
        email_sender = self.get_tool_config('EMAIL_ADDRESS')
        email_password = self.get_tool_config('EMAIL_PASSWORD')
        if email_sender == "":
            return "Error: Email Not Sent. Enter a valid Email Address."
        if email_password == "":
            return "Error: Email Not Sent. Enter a valid Email Password."
        imap_server = self.get_tool_config('EMAIL_IMAP_SERVER')
        conn = ImapEmail().imap_open(imap_folder, email_sender, email_password, imap_server)
        messages = []
        try:
            # Select the folder the caller asked for; this was previously hard-coded to "INBOX".
            status, data = conn.select(imap_folder)
            if status != "OK":
                return f"Error: Unable to open folder {imap_folder}"
            num_of_messages = int(data[0])
            # IMAP sequence numbers run from 1 to num_of_messages, with the newest last.
            # Page p covers the `limit` messages that end `p * limit` messages before the newest.
            newest = num_of_messages - max(page or 0, 0) * limit
            oldest = max(newest - limit, 0)  # never fetch sequence number 0 or below
            for i in range(newest, oldest, -1):
                res, msg = conn.fetch(str(i), "(RFC822)")
                email_msg = {}
                for response in msg:
                    self._process_message(email_msg, response)
                messages.append(email_msg)
                if TokenCounter.count_text_tokens(json.dumps(messages)) > self.max_token_limit:
                    break
        finally:
            # Always release the IMAP session, even if a fetch or parse fails.
            conn.logout()
        if not messages:
            return f"There are no Email in your folder {imap_folder}"
        return messages

    @staticmethod
    def _decode_payload(part):
        """Return the part's decoded payload as UTF-8 text, or None if there is no payload or it cannot be decoded."""
        try:
            return part.get_payload(decode=True).decode()
        except Exception:
            # Best effort: a missing payload or undecodable bytes just means no body text.
            return None

    def _process_message(self, email_msg, response):
        """
        Fill ``email_msg`` with headers and a plain-text body from one FETCH response item.

        Non-tuple items are ignored. Parts whose Content-Disposition contains "attachment"
        are passed to ReadEmail().download_attachment.
        """
        if not isinstance(response, tuple):
            return
        msg = email.message_from_bytes(response[1])
        email_msg["From"], email_msg["To"], email_msg["Date"], email_msg[
            "Subject"] = ReadEmail().obtain_header(msg)
        if msg.is_multipart():
            for part in msg.walk():
                content_type = part.get_content_type()
                content_disposition = str(part.get("Content-Disposition"))
                if content_type == "text/plain" and "attachment" not in content_disposition:
                    # Decode this part's own payload. Never reuse a previous part's body
                    # after a decode failure.
                    body = self._decode_payload(part)
                    if body is not None:
                        email_msg["Message Body"] = ReadEmail().clean_email_body(body)
                elif "attachment" in content_disposition:
                    ReadEmail().download_attachment(part, email_msg["Subject"])
        elif msg.get_content_type() == "text/plain":
            body = self._decode_payload(msg)
            if body is not None:
                email_msg["Message Body"] = ReadEmail().clean_email_body(body)
================================================
FILE: superagi/tools/email/send_email.py
================================================
import imaplib
import smtplib
import time
from email.message import EmailMessage
from typing import Type
from pydantic import BaseModel, Field
from superagi.helper.imap_email import ImapEmail
from superagi.tools.base_tool import BaseTool
# Argument schema for SendEmailTool: recipient, subject, and body of the email.
class SendEmailInput(BaseModel):
    to: str = Field(
        default=...,
        description="Email Address of the Receiver, default email address is 'example@example.com'",
    )
    subject: str = Field(default=..., description="Subject of the Email to be sent")
    body: str = Field(
        default=...,
        description="Email Body to be sent. Escape special characters in the body. Do not add senders details and end it with Warm Regards without entering any name.",
    )
class SendEmailTool(BaseTool):
    """
    Send an Email tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name: str = "Send Email"
    args_schema: Type[BaseModel] = SendEmailInput
    description: str = "Send an Email"

    def _execute(self, to: str, subject: str, body: str) -> str:
        """
        Execute the send email tool.

        Args:
            to : The email address of the receiver.
            subject : The subject of the email.
            body : The body of the email.

        Returns:
            success or error message.
        """
        email_sender = self.get_tool_config('EMAIL_ADDRESS')
        email_password = self.get_tool_config('EMAIL_PASSWORD')
        if not email_sender or email_sender.isspace():
            return "Error: Email Not Sent. Enter a valid Email Address."
        if not email_password or email_password.isspace():
            return "Error: Email Not Sent. Enter a valid Email Password."

        message = EmailMessage()
        message["Subject"] = subject
        message["From"] = email_sender
        message["To"] = to
        signature = self.get_tool_config('EMAIL_SIGNATURE')
        if signature:
            body += f"\n{signature}"
        # The schema asks for special characters to be escaped, so turn literal "\n"
        # sequences back into real line breaks.
        message.set_content(body.replace('\\n', '\n'))

        send_to_draft = (self.get_tool_config('EMAIL_DRAFT_MODE') or "FALSE").upper() == "TRUE"
        if send_to_draft:
            draft_folder = self.get_tool_config('EMAIL_DRAFT_FOLDER') or "Drafts"
            imap_server = self.get_tool_config('EMAIL_IMAP_SERVER')
            conn = ImapEmail().imap_open(draft_folder, email_sender, email_password, imap_server)
            try:
                conn.append(
                    draft_folder,
                    "",
                    imaplib.Time2Internaldate(time.time()),
                    str(message).encode("UTF-8")
                )
            finally:
                # Close the IMAP session; previously it was never logged out.
                conn.logout()
            return f"Email went to {draft_folder}"

        if message["To"] == "example@example.com":
            return "Error: Email Not Sent. Enter an Email Address."
        smtp_host = self.get_tool_config('EMAIL_SMTP_HOST')
        smtp_port = self.get_tool_config('EMAIL_SMTP_PORT')
        # Leaving the `with` block sends QUIT and closes the connection, so no explicit quit() is needed.
        with smtplib.SMTP(smtp_host, smtp_port) as smtp:
            smtp.ehlo()
            smtp.starttls()
            smtp.login(email_sender, email_password)
            smtp.send_message(message)
        return f"Email was sent to {to}"
================================================
FILE: superagi/tools/email/send_email_attachment.py
================================================
import imaplib
import mimetypes
import os
import smtplib
import time
from email.message import EmailMessage
from email.mime.application import MIMEApplication
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Type
from pydantic import BaseModel, Field
from superagi.config.config import get_config
from superagi.helper.imap_email import ImapEmail
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.tools.base_tool import BaseTool
from superagi.config.config import get_config
from superagi.types.storage_types import StorageType
# Argument schema for SendEmailAttachmentTool: recipient, subject, body, and the
# name of the resource file to attach.
class SendEmailAttachmentInput(BaseModel):
    to: str = Field(
        default=...,
        description="Email Address of the Receiver, default email address is 'example@example.com'",
    )
    subject: str = Field(default=..., description="Subject of the Email to be sent")
    body: str = Field(
        default=...,
        description="Email Body to be sent, Do not add senders details in the email body and end it with Warm Regards without entering any name.",
    )
    filename: str = Field(default=..., description="Name of the file to be sent as an Attachment with Email")
class SendEmailAttachmentTool(BaseTool):
    """
    Send an Email with Attachment tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name: str = "Send Email with Attachment"
    args_schema: Type[BaseModel] = SendEmailAttachmentInput
    description: str = "Send an Email with a file attached to it"
    agent_id: int = None
    agent_execution_id: int = None

    def _execute(self, to: str, subject: str, body: str, filename: str) -> str:
        """
        Execute the send email tool with attachment.

        Args:
            to : The email address of the receiver.
            subject : The subject of the email.
            body : The body of the email.
            filename : The name of the file to be sent as an attachment with the email.

        Returns:
            success or failure message

        Raises:
            FileNotFoundError: if the resource path cannot be resolved, or the local file does not exist.
        """
        final_path = ResourceHelper.get_agent_read_resource_path(file_name=filename,
                                                                 agent=Agent.get_agent_from_id(
                                                                     self.toolkit_config.session,
                                                                     self.agent_id),
                                                                 agent_execution=AgentExecution.get_agent_execution_from_id(
                                                                     session=self.toolkit_config.session,
                                                                     agent_execution_id=self.agent_execution_id)
                                                                 )
        # Reject an unresolved path up front. Before this check, an S3 setup reached
        # mimetypes/split with None and raised TypeError.
        if final_path is None:
            raise FileNotFoundError(f"File '{filename}' not found.")

        if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
            attachment_data = S3Helper().read_binary_from_s3(final_path)
        else:
            if not os.path.exists(final_path):
                raise FileNotFoundError(f"File '{filename}' not found.")
            with open(final_path, "rb") as file:
                attachment_data = file.read()
        # MIMEApplication defaults to application/octet-stream.
        attachment = MIMEApplication(attachment_data)
        attachment.add_header('Content-Disposition', 'attachment', filename=final_path.split('/')[-1])
        return self.send_email_with_attachment(to, subject, body, attachment)

    def send_email_with_attachment(self, to, subject, body, attachment) -> str:
        """
        Send an email with attachment.

        Args:
            to : The email address of the receiver.
            subject : The subject of the email.
            body : The body of the email.
            attachment : The MIME part to attach; a falsy value sends the email without one.

        Returns:
            A message naming the recipient or draft folder, or an "Error: ..." message.
        """
        email_sender = self.get_tool_config('EMAIL_ADDRESS')
        email_password = self.get_tool_config('EMAIL_PASSWORD')
        if not email_sender or email_sender.isspace():
            return "Error: Email Not Sent. Enter a valid Email Address."
        if not email_password or email_password.isspace():
            return "Error: Email Not Sent. Enter a valid Email Password."

        message = MIMEMultipart()
        message["Subject"] = subject
        message["From"] = email_sender
        message["To"] = to
        signature = self.get_tool_config('EMAIL_SIGNATURE')
        if signature:
            body += f"\n{signature}"
        message.attach(MIMEText(body, 'plain'))
        if attachment:
            message.attach(attachment)

        send_to_draft = (self.get_tool_config('EMAIL_DRAFT_MODE') or "FALSE").upper() == "TRUE"
        if send_to_draft:
            # Use the same "Drafts" fallback as SendEmailTool; otherwise a missing
            # EMAIL_DRAFT_FOLDER passes None to the IMAP calls.
            draft_folder = self.get_tool_config('EMAIL_DRAFT_FOLDER') or "Drafts"
            imap_server = self.get_tool_config('EMAIL_IMAP_SERVER')
            conn = ImapEmail().imap_open(draft_folder, email_sender, email_password, imap_server)
            try:
                conn.append(
                    draft_folder,
                    "",
                    imaplib.Time2Internaldate(time.time()),
                    str(message).encode("UTF-8")
                )
            finally:
                # Close the IMAP session; previously it was never logged out.
                conn.logout()
            return f"Email went to {draft_folder}"

        if message["To"] == "example@example.com":
            return "Error: Email Not Sent. Enter an Email Address."
        smtp_host = self.get_tool_config('EMAIL_SMTP_HOST')
        smtp_port = self.get_tool_config('EMAIL_SMTP_PORT')
        # Leaving the `with` block sends QUIT and closes the connection, so no explicit quit() is needed.
        with smtplib.SMTP(smtp_host, smtp_port) as smtp:
            smtp.ehlo()
            smtp.starttls()
            smtp.login(email_sender, email_password)
            smtp.send_message(message)
        return f"Email was sent to {to}"
================================================
FILE: superagi/tools/file/__init__.py
================================================
================================================
FILE: superagi/tools/file/append_file.py
================================================
import os
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.helper.resource_helper import ResourceHelper
from superagi.models.agent_execution import AgentExecution
from superagi.tools.base_tool import BaseTool
from superagi.models.agent import Agent
from superagi.types.storage_types import StorageType
from superagi.config.config import get_config
from superagi.helper.s3_helper import S3Helper
from superagi.resource_manager.file_manager import FileManager
class AppendFileInput(BaseModel):
    """Input for AppendFileTool."""
    file_name: str = Field(..., description="Name of the file to write")
    content: str = Field(..., description="The text to append to the file")
class AppendFileTool(BaseTool):
    """
    Append File tool

    Attributes:
        name : The name.
        agent_id: The agent id.
        agent_execution_id: The agent execution id.
        description : The description.
        args_schema : The args schema.
        resource_manager: File manager used to rewrite the file when storage is S3.
    """
    name: str = "Append File"
    agent_id: int = None
    agent_execution_id: int = None
    args_schema: Type[BaseModel] = AppendFileInput
    description: str = "Append text to a file"
    resource_manager: Optional[FileManager] = None

    def _execute(self, file_name: str, content: str):
        """
        Execute the append file tool.

        On S3 the object is read, deleted and rewritten with the new content appended (only .txt is
        supported there). On local storage the file is opened in append mode, creating parent
        directories as needed.

        Args:
            file_name : The name of the file to write.
            content : The text to append to the file.

        Returns:
            success or error message.
        """
        final_path = ResourceHelper.get_agent_write_resource_path(
            file_name,
            Agent.get_agent_from_id(session=self.toolkit_config.session, agent_id=self.agent_id),
            AgentExecution.get_agent_execution_from_id(session=self.toolkit_config.session,
                                                       agent_execution_id=self.agent_execution_id))
        if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
            previous_content = self.get_previous_content(final_path)
            if previous_content is None:
                return "Append file only supported for .txt Files."
            # False signals a failed read; an empty string is a valid, existing (empty) file.
            if previous_content is False:
                return "File not Found."
            S3Helper().delete_file(final_path)
            return self.resource_manager.write_file(file_name, previous_content + content)
        try:
            directory = os.path.dirname(final_path)
            # dirname() is "" for a bare file name, and os.makedirs("") raises.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(final_path, 'a+', encoding="utf-8") as file:
                file.write(content)
            return "File written to successfully."
        except Exception as err:
            return f"Error: {err}"

    def get_previous_content(self, final_path):
        """
        Read the current content of an S3 .txt object.

        Args:
            final_path : The S3 key of the file.

        Returns:
            The file content, None if the file is not a .txt file, or False if reading it failed.
        """
        if not final_path.split('/')[-1].lower().endswith('.txt'):
            return None
        try:
            return S3Helper().read_from_s3(final_path)
        except Exception:
            return False
================================================
FILE: superagi/tools/file/delete_file.py
================================================
import os
from typing import Type
from pydantic import BaseModel, Field
from superagi.helper.resource_helper import ResourceHelper
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent import Agent
from superagi.tools.base_tool import BaseTool
from superagi.types.storage_types import StorageType
from superagi.config.config import get_config
from superagi.helper.s3_helper import S3Helper
class DeleteFileInput(BaseModel):
    """Input for DeleteFileTool."""
    file_name: str = Field(..., description="Name of the file to delete")
class DeleteFileTool(BaseTool):
    """
    Tool that removes one file from the agent's resource storage (S3 or local disk).

    Attributes:
        name : Tool name shown to the agent.
        agent_id : Id of the agent whose write-resource path is used.
        agent_execution_id : Id of the execution whose write-resource path is used.
        args_schema : Pydantic schema of the tool arguments.
        description : Tool description shown to the agent.
    """
    name: str = "Delete File"
    agent_id: int = None
    agent_execution_id: int = None
    args_schema: Type[BaseModel] = DeleteFileInput
    description: str = "Delete a file"

    def _execute(self, file_name: str):
        """
        Delete the named file from the configured storage backend.

        Args:
            file_name : Name of the file to remove.

        Returns:
            "File deleted successfully." or "Error: <reason>" when the deletion raised.
        """
        session = self.toolkit_config.session
        agent = Agent.get_agent_from_id(session=session, agent_id=self.agent_id)
        execution = AgentExecution.get_agent_execution_from_id(session=session,
                                                               agent_execution_id=self.agent_execution_id)
        target_path = ResourceHelper.get_agent_write_resource_path(file_name, agent, execution)

        storage = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value))
        uses_s3 = storage == StorageType.S3
        try:
            if uses_s3:
                S3Helper().delete_file(target_path)
            else:
                os.remove(target_path)
        except Exception as err:
            return f"Error: {err}"
        return "File deleted successfully."
================================================
FILE: superagi/tools/file/file_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.file.append_file import AppendFileTool
from superagi.tools.file.delete_file import DeleteFileTool
from superagi.tools.file.list_files import ListFileTool
from superagi.tools.file.read_file import ReadFileTool
from superagi.tools.file.write_file import WriteFileTool
from superagi.types.key_type import ToolConfigKeyType
from superagi.models.tool_config import ToolConfig
class FileToolkit(BaseToolkit, ABC):
    # Bundles the file tools: append, delete, list, read and write.
    name: str = "File Toolkit"
    description: str = "File Tool kit contains all tools related to file operations"

    def get_tools(self) -> List[BaseTool]:
        """Return a fresh instance of every file tool, in a fixed order."""
        tool_classes = (AppendFileTool, DeleteFileTool, ListFileTool, ReadFileTool, WriteFileTool)
        return [tool_class() for tool_class in tool_classes]

    def get_env_keys(self) -> List[ToolConfiguration]:
        """The file tools require no configuration keys."""
        return list()
================================================
FILE: superagi/tools/file/list_files.py
================================================
import os
from typing import Type
from pydantic import BaseModel, Field
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from superagi.tools.base_tool import BaseTool
from superagi.models.agent import Agent
from superagi.types.storage_types import StorageType
from superagi.config.config import get_config
class ListFileInput(BaseModel):
    # The List File tool takes no arguments.
    ...
class ListFileTool(BaseTool):
    """
    List File tool

    Attributes:
        name : The name.
        agent_id: The agent id.
        description : The description.
        args_schema : The args schema.
    """
    name: str = "List File"
    agent_id: int = None
    args_schema: Type[BaseModel] = ListFileInput
    description: str = "lists files in a directory recursively"

    def _execute(self):
        """
        Execute the list file tool over the agent's root input directory.

        Takes no arguments. The directory is ResourceHelper.get_root_input_dir(); when that path
        contains an "{agent_id}" placeholder it is resolved to the agent-level path.

        Returns:
            list of files in the input directory (see list_files).
        """
        input_directory = ResourceHelper.get_root_input_dir()
        if "{agent_id}" in input_directory:
            agent = Agent.get_agent_from_id(session=self.toolkit_config.session, agent_id=self.agent_id)
            input_directory = ResourceHelper.get_formatted_agent_level_path(agent=agent, path=input_directory)
        return self.list_files(input_directory)

    def list_files(self, directory):
        """
        List the files under a directory.

        Args:
            directory : Directory (or S3 prefix) to list.

        Returns:
            For S3 storage, the result of S3Helper.list_files_from_s3. For local storage, the base
            names of all files found recursively, skipping dot-files and anything whose directory
            path contains "__pycache__".
        """
        if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
            return S3Helper().list_files_from_s3(directory)
        found_files = []
        for root, _dirs, files in os.walk(directory):
            # Checked once per directory rather than once per file.
            if "__pycache__" in root:
                continue
            found_files.extend(file for file in files if not file.startswith("."))
        return found_files
================================================
FILE: superagi/tools/file/read_file.py
================================================
import os
from typing import Type, Optional
import ebooklib
import bs4
from bs4 import BeautifulSoup
from pydantic import BaseModel, Field
from ebooklib import epub
from superagi.helper.validate_csv import correct_csv_encoding
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from superagi.models.agent_execution import AgentExecution
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.models.agent import Agent
from superagi.types.storage_types import StorageType
from superagi.config.config import get_config
from unstructured.partition.auto import partition
from superagi.lib.logger import logger
class ReadFileSchema(BaseModel):
    """Input for ReadFileTool."""
    file_name: str = Field(..., description="Path of the file to read")
class ReadFileTool(BaseTool):
    """
    Read File tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
        agent_id : The agent id.
        agent_execution_id : The agent execution id.
        resource_manager : File resource manager.
    """
    name: str = "Read File"
    agent_id: int = None
    agent_execution_id: int = None
    args_schema: Type[BaseModel] = ReadFileSchema
    description: str = "Reads the file content in a specified location"
    resource_manager: Optional[FileManager] = None

    def _execute(self, file_name: str):
        """
        Execute the read file tool.

        Args:
            file_name : The name of the file to read.

        Returns:
            The text content of the file.

        Raises:
            FileNotFoundError: If the path cannot be resolved, or (local storage) the file does not exist.
        """
        final_path = ResourceHelper.get_agent_read_resource_path(
            file_name,
            agent=Agent.get_agent_from_id(session=self.toolkit_config.session, agent_id=self.agent_id),
            agent_execution=AgentExecution.get_agent_execution_from_id(
                session=self.toolkit_config.session,
                agent_execution_id=self.agent_execution_id))
        if final_path is None:
            raise FileNotFoundError(f"File '{file_name}' not found.")

        if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
            if final_path.split('/')[-1].lower().endswith('.txt'):
                return S3Helper().read_from_s3(final_path)
            return self._read_s3_binary_file(final_path)

        if not os.path.exists(final_path):
            raise FileNotFoundError(f"File '{file_name}' not found.")
        return self._extract_content(final_path)

    def _read_s3_binary_file(self, final_path):
        """Download a non-.txt S3 object into a private temp dir, extract its text, then clean up."""
        import shutil
        import tempfile

        # A private temp dir replaces the old "/" + file_name path, which wrote into the filesystem
        # root and let "../" in the requested name escape it.
        temp_dir = tempfile.mkdtemp()
        try:
            # Keep the original base name so extension-based parsing (.epub / .csv / partition) still works.
            temporary_file_path = os.path.join(temp_dir, os.path.basename(final_path))
            with open(temporary_file_path, "wb") as f:
                f.write(S3Helper().read_binary_from_s3(final_path))
            return self._extract_content(temporary_file_path)
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

    @staticmethod
    def _extract_content(path):
        """Return the text of a local file: ebooklib for .epub, unstructured's partition otherwise."""
        lowered = path.lower()
        if lowered.endswith('.epub'):
            book = epub.read_epub(path)
            chapters = [BeautifulSoup(item.get_content(), 'html.parser').get_text()
                        for item in book.get_items_of_type(ebooklib.ITEM_DOCUMENT)]
            return "\n".join(chapters)
        if lowered.endswith('.csv'):
            correct_csv_encoding(path)
        elements = partition(path)
        return "\n\n".join(str(el) for el in elements)
================================================
FILE: superagi/tools/file/write_file.py
================================================
from typing import Type, Optional
from pydantic import BaseModel, Field
# from superagi.helper.s3_helper import upload_to_s3
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
# from superagi.helper.s3_helper import upload_to_s3
class WriteFileInput(BaseModel):
    """Input for WriteFileTool."""
    file_name: str = Field(..., description="Name of the file to write. Only include the file name. Don't include path.")
    content: str = Field(..., description="File content to write")
class WriteFileTool(BaseTool):
    """
    Tool that writes text content to a named file through the configured FileManager.

    Attributes:
        name : Tool name shown to the agent.
        description : Tool description shown to the agent.
        agent_id : Id of the agent the tool runs for.
        args_schema : Pydantic schema of the tool arguments.
        resource_manager : File manager that performs the actual write.
    """
    name: str = "Write File"
    args_schema: Type[BaseModel] = WriteFileInput
    description: str = "Writes text to a file"
    agent_id: int = None
    resource_manager: Optional[FileManager] = None

    class Config:
        # Presumably needed so pydantic accepts the non-pydantic FileManager field.
        arbitrary_types_allowed = True

    def _execute(self, file_name: str, content: str):
        """
        Write content to file_name using the resource manager.

        Args:
            file_name : Target file name (no path).
            content : Text to write.

        Returns:
            Whatever FileManager.write_file returns (a success or failure message).
        """
        manager = self.resource_manager
        return manager.write_file(file_name, content)
================================================
FILE: superagi/tools/github/README.MD
================================================
# SuperAGI GitHub Tool
The SuperAGI GitHub Tool enables users to perform various operations on GitHub repositories which include adding files or folders, deleting files, and searching for files or folders within a repository.
## 💡 Features
1. **Add Files or Folders:** With SuperAGI's GitHub Tool, you can easily add files or folders to a GitHub repository.
2. **Delete Files:** Remove files from a GitHub repository effortlessly using SuperAGI's GitHub Tool.
3. **Search for Files or Folders:** Find specific files or folders within a GitHub repository using SuperAGI's GitHub Tool.
## ⚙️ Installation
### 🛠 **Setting Up SuperAGI**
Set up SuperAGI by following the instructions provided in the [SuperAGI repository's README file](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.md).
### 🔧 **Add GitHub Configuration Settings in SuperAGI Dashboard**
Add the following configuration settings to the GitHub Toolkit Page:
1. _GitHub Access Token:_
- Obtain a GitHub access token with the necessary permissions for accessing and modifying repositories.
- Go to Settings in your GitHub Account. Then go to Developer Settings.
- Click on "Personal access tokens". Then click on "Tokens (classic)".

- Click on "Generate new token". Then choose "Generate new token (classic)".

- Write a Note about what the token is for and choose an appropriate expiration date.

- Select all the scopes (this way, you won't have to create a new access token every time new scopes are added to the code).
- Click on Generate New Token.
- Copy the token and store it somewhere secure, for example in a separate file, because GitHub shows it only once.
2. _GitHub Username:_
- You can find your GitHub username on your GitHub Profile.
3. _Configuring in SuperAGI Dashboard:_
- Add your generated token and your username on the GitHub Toolkit page.
## Running SuperAGI GitHub Tool
1. **Add Files or Folders:**
To add a file or folder to a GitHub repository, specify in your goal the repository name, the owner's username, and the path where the file/folder should be added. SuperAGI will upload it to the repository and automatically raise a PR for it. By default, it uses the main branch; if you want to add it to a different branch, mention that branch in the goal.
2. **Delete Files:**
To delete a file from a GitHub repository, specify in your goal the repository name, the owner's username, and the path to the file you want to delete. SuperAGI will handle the deletion and raise a PR for it. By default, it uses the main branch; if you want to delete the file from a different branch, mention that branch in the goal.
3. **Search for Files or Folders**
To search for files or folders within a GitHub repository, specify in your goal the repository name, the owner's username, and the name or path of the file/folder you're looking for. SuperAGI will provide you with the search results. By default, it searches the main branch; if you want to search a different branch, mention that branch in the goal.
================================================
FILE: superagi/tools/github/__init__.py
================================================
================================================
FILE: superagi/tools/github/add_file.py
================================================
from typing import Type
from pydantic import BaseModel, Field
from superagi.helper.github_helper import GithubHelper
from superagi.tools.base_tool import BaseTool
class GithubAddFileSchema(BaseModel):
    """Input for GithubAddFileTool."""
    repository_name: str = Field(
        ...,
        description="Repository name in which file has to be added",
    )
    base_branch: str = Field(
        ...,
        description="branch to interact with",
    )
    file_name: str = Field(
        ...,
        description="file name to be added to repository",
    )
    folder_path: str = Field(
        ...,
        description="folder path for the file to be stored",
    )
    commit_message: str = Field(
        ...,
        description="clear description of the contents of file",
    )
    repository_owner: str = Field(
        ...,
        description="Owner of the github repository",
    )
class GithubAddFileTool(BaseTool):
    """
    Add File tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
        agent_id : The agent id (forwarded to GithubHelper.add_file).
        agent_execution_id : The agent execution id (forwarded to GithubHelper.add_file).
    """
    name: str = "Github Add File"
    args_schema: Type[BaseModel] = GithubAddFileSchema
    description: str = "Add a file or folder to a particular github repository"
    agent_id: int = None
    agent_execution_id: int = None

    def _execute(self, repository_name: str, base_branch: str, commit_message: str, repository_owner: str,
                 file_name='.gitkeep', folder_path=None) -> str:
        """
        Execute the add file tool.

        Forks the repository when it is not owned by the configured user, creates the 'new-file'
        branch, adds the file on it and opens a pull request against base_branch.

        Args:
            repository_name : The name of the repository to add file to.
            base_branch : The branch to interact with.
            commit_message : Clear description of the contents of file.
            repository_owner : Owner of the GitHub repository.
            file_name : The name of the file to add.
            folder_path : The path of the folder to add the file to.

        Returns:
            Pull request success message if pull request is created successfully else error message.
        """
        session = self.toolkit_config.session
        try:
            github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN")
            github_username = self.get_tool_config("GITHUB_USERNAME")
            github_helper = GithubHelper(github_access_token, github_username)
            head_branch = 'new-file'
            headers = {
                "Authorization": f"token {github_access_token}" if github_access_token else None,
                "Content-Type": "application/vnd.github+json"
            }
            if repository_owner != github_username:
                github_helper.make_fork(repository_owner, repository_name, base_branch, headers)
            github_helper.create_branch(repository_name, base_branch, head_branch, headers)
            file_response = github_helper.add_file(repository_owner, repository_name, file_name, folder_path,
                                                   head_branch, base_branch, headers, commit_message,
                                                   self.agent_id, self.agent_execution_id, session)
            pr_response = github_helper.create_pull_request(repository_owner, repository_name, head_branch,
                                                            base_branch, headers)
            # 422 is accepted alongside 201 — presumably GitHub's reply when the file/PR already exists.
            if pr_response in (201, 422) and file_response in (201, 422):
                return "Pull request to add file/folder has been created"
            return "Error while adding file."
        except Exception as err:
            return f"Error: Unable to add file/folder to repository {err}"
================================================
FILE: superagi/tools/github/delete_file.py
================================================
from typing import Type
from pydantic import BaseModel, Field
from superagi.tools.base_tool import BaseTool
from superagi.helper.github_helper import GithubHelper
from superagi.lib.logger import logger
class GithubDeleteFileSchema(BaseModel):
    """Input for GithubDeleteFileTool."""
    repository_name: str = Field(
        ...,
        description="Repository name in which file has to be deleted",
    )
    base_branch: str = Field(
        ...,
        description="branch to interact with",
    )
    file_name: str = Field(
        ...,
        description="file name to be deleted in the repository",
    )
    folder_path: str = Field(
        ...,
        description="folder path in which file to be deleted is present",
    )
    commit_message: str = Field(
        ...,
        description="clear description of files that are being deleted",
    )
    repository_owner: str = Field(
        ...,
        description="Owner of the github repository",
    )
class GithubDeleteFileTool(BaseTool):
    """
    Delete File tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name: str = "Github Delete File"
    args_schema: Type[BaseModel] = GithubDeleteFileSchema
    description: str = "Delete a file or folder inside a particular github repository"

    def _execute(self, repository_name: str, base_branch: str, file_name: str, commit_message: str,
                 repository_owner: str, folder_path=None) -> str:
        """
        Execute the delete file tool.

        Forks the repository when it is not owned by the configured user, creates (and, when the
        branch call returns 201/422, syncs) the 'new-file' branch, deletes the file on it and opens
        a pull request against base_branch.

        Args:
            repository_name : The name of the repository to delete file from.
            base_branch : The branch to interact with.
            file_name : The name of the file to delete.
            commit_message : Clear description of the contents of file.
            repository_owner : Owner of the GitHub repository.
            folder_path : The path of the folder to delete the file from.

        Returns:
            success message mentioning the pull request name for the delete file operation, or error message.
        """
        try:
            github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN")
            github_username = self.get_tool_config("GITHUB_USERNAME")
            github_helper = GithubHelper(github_access_token, github_username)
            head_branch = 'new-file'
            headers = {
                "Authorization": f"token {github_access_token}" if github_access_token else None,
                "Content-Type": "application/vnd.github+json"
            }
            if repository_owner != github_username:
                github_helper.make_fork(repository_owner, repository_name, base_branch, headers)
            branch_response = github_helper.create_branch(repository_name, base_branch, head_branch, headers)
            # Pass one pre-formatted string: with stdlib-style loggers, an extra positional argument
            # and no %-placeholder in the message is a formatting error.
            logger.info(f"branch_response: {branch_response}")
            if branch_response in (201, 422):
                github_helper.sync_branch(github_username, repository_name, base_branch, head_branch, headers)
            file_response = github_helper.delete_file(repository_name, file_name, folder_path, commit_message,
                                                      head_branch, headers)
            pr_response = github_helper.create_pull_request(repository_owner, repository_name, head_branch,
                                                            base_branch, headers)
            if pr_response in (201, 422) and file_response == 200:
                return f"Pull request to Delete {file_name} has been created"
            return "Error while deleting file"
        except Exception as err:
            # Include the cause, consistent with the other GitHub tools, instead of discarding it.
            return f"Error: Unable to delete file {file_name} in {repository_name} repository: {err}"
================================================
FILE: superagi/tools/github/fetch_pull_request.py
================================================
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.helper.github_helper import GithubHelper
from superagi.llms.base_llm import BaseLlm
from superagi.tools.base_tool import BaseTool
class GithubFetchPullRequestSchema(BaseModel):
    """Input for GithubFetchPullRequest."""
    repository_name: str = Field(
        ...,
        description="Name of the repository to fetch pull requests from",
    )
    repository_owner: str = Field(
        ...,
        description="Owner of the github repository",
    )
    time_in_seconds: int = Field(
        ...,
        description="Gets pull requests from last `time_in_seconds` seconds",
    )
class GithubFetchPullRequest(BaseTool):
    """
    Fetch pull request tool

    Attributes:
        llm : Optional LLM attached to the tool (not used by _execute).
        name : The name.
        description : The description.
        args_schema : The args schema.
        agent_id: The agent id.
        agent_execution_id: The agent execution id.
    """
    llm: Optional[BaseLlm] = None
    name: str = "Github Fetch Pull Requests"
    args_schema: Type[BaseModel] = GithubFetchPullRequestSchema
    description: str = "Fetch pull requests from github"
    agent_id: int = None
    agent_execution_id: int = None

    def _execute(self, repository_name: str, repository_owner: str, time_in_seconds: int = 86400) -> str:
        """
        Execute the fetch pull requests tool.

        Args:
            repository_name: The name of the repository to fetch pull requests from.
            repository_owner: Owner of the GitHub repository.
            time_in_seconds: Gets pull requests from last `time_in_seconds` seconds (default: one day).

        Returns:
            "Pull requests: " followed by the pull requests returned by GithubHelper, or an error message.
        """
        try:
            github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN")
            github_username = self.get_tool_config("GITHUB_USERNAME")
            github_helper = GithubHelper(github_access_token, github_username)
            pull_request_urls = github_helper.get_pull_requests_created_in_last_x_seconds(repository_owner,
                                                                                          repository_name,
                                                                                          time_in_seconds)
            return "Pull requests: " + str(pull_request_urls)
        except Exception as err:
            return f"Error: Unable to fetch pull requests {err}"
================================================
FILE: superagi/tools/github/github_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.github.add_file import GithubAddFileTool
from superagi.tools.github.delete_file import GithubDeleteFileTool
from superagi.tools.github.fetch_pull_request import GithubFetchPullRequest
from superagi.tools.github.search_repo import GithubRepoSearchTool
from superagi.tools.github.review_pull_request import GithubReviewPullRequest
from superagi.types.key_type import ToolConfigKeyType
class GitHubToolkit(BaseToolkit, ABC):
    # Bundles the GitHub tools: add file, delete file, repo search, PR review and PR fetch.
    name: str = "GitHub Toolkit"
    description: str = "GitHub Tool Kit contains all tools related to GitHub"

    def get_tools(self) -> List[BaseTool]:
        """Return a fresh instance of every GitHub tool."""
        return [GithubAddFileTool(), GithubDeleteFileTool(), GithubRepoSearchTool(), GithubReviewPullRequest(),
                GithubFetchPullRequest()]

    def get_env_keys(self) -> List[ToolConfiguration]:
        """Configuration keys: the (secret) GitHub access token and the GitHub username."""
        return [
            ToolConfiguration(key="GITHUB_ACCESS_TOKEN", key_type=ToolConfigKeyType.STRING, is_required=True,
                              is_secret=True),
            ToolConfiguration(key="GITHUB_USERNAME", key_type=ToolConfigKeyType.STRING, is_required=True,
                              is_secret=False)
        ]
================================================
FILE: superagi/tools/github/prompts/code_review.txt
================================================
Your purpose is to act as a highly experienced software engineer and provide a thorough review of the code chunks and suggest code snippets to improve key areas such as:
- Logic
- Modularity
- Maintainability
- Complexity
Do not comment on minor code style issues or missing comments/documentation. Identify and resolve significant concerns to improve overall code quality while deliberately disregarding minor issues.
Following is the github pull request diff content:
```
{{DIFF_CONTENT}}
```
Instructions:
1. Do not comment on existing lines and deleted lines.
2. Ignore lines that start with '-'.
3. Only consider lines starting with '+' for review.
4. Do not comment on frontend and graphql code.
Respond with only valid JSON conforming to the following schema:
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"comments": {
"type": "array",
"items": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The path to the file where the comment should be added."
},
"line": {
"type": "integer",
"description": "The line number where the comment should be added. "
},
"comment": {
"type": "string",
"description": "The content of the comment."
}
},
"required": ["file_path", "line", "comment"]
}
}
},
"required": ["comments"]
}
Ensure the response is valid JSON conforming to the schema above.
================================================
FILE: superagi/tools/github/review_pull_request.py
================================================
import ast
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.github_helper import GithubHelper
from superagi.helper.json_cleaner import JsonCleaner
from superagi.helper.prompt_reader import PromptReader
from superagi.helper.token_counter import TokenCounter
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.tools.base_tool import BaseTool
class GithubReviewPullRequestSchema(BaseModel):
    """Input for GithubReviewPullRequest."""
    repository_name: str = Field(
        ...,
        description="Name of the repository containing the pull request",
    )
    repository_owner: str = Field(
        ...,
        description="Owner of the github repository",
    )
    pull_request_number: int = Field(
        ...,
        description="Pull request number",
    )
class GithubReviewPullRequest(BaseTool):
    """
    Reviews the github pull request and adds comments inline

    Attributes:
        llm : The LLM used to generate the review comments.
        name : The name.
        description : The description.
        args_schema : The args schema.
        agent_id : The agent id.
        agent_execution_id : The agent execution id.
    """
    llm: Optional[BaseLlm] = None
    name: str = "Github Review Pull Request"
    args_schema: Type[BaseModel] = GithubReviewPullRequestSchema
    description: str = "Review a github pull request and add inline comments to it"
    agent_id: int = None
    agent_execution_id: int = None
def _execute(self, repository_name: str, repository_owner: str, pull_request_number: int) -> str:
    """
    Execute the review pull request tool.

    Fetches the pull request diff, splits it per file into chunks sized against the model's token
    limit, asks the LLM to review each chunk, and posts the resulting comments inline on the latest
    commit of the pull request.

    Args:
        repository_name: The name of the repository containing the pull request.
        repository_owner: Owner of the GitHub repository.
        pull_request_number: pull request number

    Returns:
        A success message naming the pull request number, or an error message if any step raised.
    """
    try:
        github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN")
        github_username = self.get_tool_config("GITHUB_USERNAME")
        github_helper = GithubHelper(github_access_token, github_username)
        pull_request_content = github_helper.get_pull_request_content(repository_owner, repository_name,
                                                                      pull_request_number)
        latest_commit_id = github_helper.get_latest_commit_id_of_pull_request(repository_owner, repository_name,
                                                                              pull_request_number)
        pull_request_arr = pull_request_content.split("diff --git")
        organisation = Agent.find_org_by_agent_id(session=self.toolkit_config.session, agent_id=self.agent_id)
        model_token_limit = TokenCounter(session=self.toolkit_config.session,
                                         organisation_id=organisation.id).token_limit(self.llm.get_model())
        pull_request_arr_parts = self.split_pull_request_content_into_multiple_parts(model_token_limit,
                                                                                     pull_request_arr)
        for content in pull_request_arr_parts:
            self.run_code_review(github_helper, content, latest_commit_id, organisation, pull_request_number,
                                 repository_name, repository_owner)
        return "Added comments to the pull request:" + str(pull_request_number)
    except Exception as err:
        return f"Error: Unable to add comments to the pull request {err}"
def run_code_review(self, github_helper, content, latest_commit_id, organisation, pull_request_number,
repository_name, repository_owner):
prompt = PromptReader.read_tools_prompt(__file__, "code_review.txt")
prompt = prompt.replace("{{DIFF_CONTENT}}", content)
messages = [{"role": "system", "content": prompt}]
total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
token_limit = TokenCounter(session=self.toolkit_config.session,
organisation_id=organisation.id).token_limit(self.llm.get_model())
result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
if 'error' in result and result['message'] is not None:
ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
response = result["content"]
if response.startswith("```") and response.endswith("```"):
response = "```".join(response.split("```")[1:-1])
response = JsonCleaner.extract_json_section(response)
comments = ast.literal_eval(response)
# Add comments in the pull request
for comment in comments['comments']:
line_number = self.get_exact_line_number(content, comment["file_path"], comment["line"])
github_helper.add_line_comment_to_pull_request(repository_owner, repository_name, pull_request_number,
latest_commit_id, comment["file_path"], line_number,
comment["comment"])
def split_pull_request_content_into_multiple_parts(self, model_token_limit: int, pull_request_arr):
pull_request_arr_parts = []
current_part = ""
for part in pull_request_arr:
total_tokens = TokenCounter.count_message_tokens([{"role": "user", "content": current_part}],
self.llm.get_model())
# we are using 60% of the model token limit
if total_tokens >= model_token_limit * 0.6:
# Add the current part to pull_request_arr_parts and reset current_sum and current_part
pull_request_arr_parts.append(current_part)
current_part = "diff --git" + part
else:
current_part += "diff --git" + part
pull_request_arr_parts.append(current_part)
return pull_request_arr_parts
def get_exact_line_number(self, diff_content, file_path, line_number):
last_content = diff_content[diff_content.index(file_path):]
last_content = last_content[last_content.index('@@'):]
return self.find_position_in_diff(last_content, line_number)
def find_position_in_diff(self, diff_content, target_line):
# Split the diff by lines and initialize variables
diff_lines = diff_content.split('\n')
position = 0
current_file_line_number = 0
# Loop through each line in the diff
for line in diff_lines:
position += 1 # Increment position for each line
if line.startswith('@@'):
# Reset the current file line number when encountering a new hunk
current_file_line_number = int(line.split('+')[1].split(',')[0]) - 1
elif not line.startswith('-'):
# Increment the current file line number for lines that are not deletions
current_file_line_number += 1
if current_file_line_number >= target_line:
# Return the position when the target line number is reached
return position
return position
================================================
FILE: superagi/tools/github/search_repo.py
================================================
from typing import Type
from pydantic import BaseModel, Field
from superagi.helper.github_helper import GithubHelper
from superagi.tools.base_tool import BaseTool
class GithubSearchRepoSchema(BaseModel):
    # Argument schema used as GithubRepoSearchTool.args_schema.
    repository_name: str = Field(
        ...,
        description="Repository name in which we have to search",
    )
    repository_owner: str = Field(
        ...,
        description="Owner of the github repository",
    )
    file_name: str = Field(
        ...,
        description="Name of the file we need to fetch from the repository",
    )
    # NOTE(review): declared required here, although GithubRepoSearchTool._execute defaults it to None.
    folder_path: str = Field(
        ...,
        description="folder path in which file is present",
    )
class GithubRepoSearchTool(BaseTool):
    """
    Search File tool: returns the content of a file stored in a GitHub repository.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name = "GithubRepo Search"
    description = (
        "Search for a file inside a Github repository"
    )
    args_schema: Type[GithubSearchRepoSchema] = GithubSearchRepoSchema

    def _execute(self, repository_owner: str, repository_name: str, file_name: str, folder_path=None) -> str:
        """
        Execute the search file tool.

        Args:
            repository_owner : The owner of the repository to search file in.
            repository_name : The name of the repository to search file in.
            file_name : The name of the file to search.
            folder_path : The path of the folder to search the file in.

        Returns:
            The content of the github file, or "File not found" if it could not be fetched.
        """
        github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN")
        github_username = self.get_tool_config("GITHUB_USERNAME")
        github_repo_search = GithubHelper(github_access_token, github_username)
        try:
            return github_repo_search.get_content_in_file(repository_owner, repository_name, file_name, folder_path)
        except Exception:
            # Any lookup failure is reported to the agent as "not found"; catching Exception (not a bare
            # except) lets KeyboardInterrupt/SystemExit propagate.
            return "File not found"
================================================
FILE: superagi/tools/google_calendar/README.md
================================================
# SuperAGI - Google Calendar Toolkit
Introducing the Google Calendar Toolkit, a powerful integration for SuperAGI. With the Google Calendar toolkit, you have the ability to do the following:
1. **Create Calendar Events**
2. **List your Calendar Events**
3. **Fetch an event from your Calendar**
4. **Delete Calendar Events**
## ⚙️ Installation
### ⚒️ Setting up of SuperAGI
Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
# ✅ Quickstart Guide:
In order to get started with integrating Google Calendar with SuperAGI, you need to do the following:
## API Creation and OAuth Consent Screen
1. Go to Google Developer Console:
[https://console.cloud.google.com/](https://console.cloud.google.com/) and create a new project. If you already have an existing project, you can use that instead:

2. After the project is created/you’re in your selected project, head to “APIs and Services”

3. Click on “ENABLED APIS AND SERVICES” and search for “Google Calendar”


4. Enable the API

5. Once the API is Enabled, go to “OAuth Consent Screen”

6. Select your User Type as “External” and click on "Create"

7. Fill in the required details such as the App Information, App Domain, Authorized Domain, and Developer contact information. Once filled in, click “Save and Continue”

8. On the next page, you don’t need to select the scopes. Proceed to “save and continue” and at the final page, review the process and click “Back to Dashboard”. With this, you’ve created your OAuth Consent Screen for Google Calendar.
9. You can now click “Publish App”.

## 🔧 Configuring endpoints & obtaining Client ID and Client Secret Key
To obtain the Client ID and Client Secret, follow these steps:
1. Go to “Credentials” Page

2. Click on “Create Credentials” and click on “OAuth Client ID”


3. Once you click on OAuth Client ID, choose the type of application as “Web Application” and give it a name of your choice

4. Create JavaScript Origins and add the following details as shown in the image:

5. Go to Authorized redirect URIs and add the following URIs:
`https://app.superagi.com/api/google/oauth-tokens`
`http://localhost:3000/api/google/oauth-tokens`

6. Once you have added the Authorized redirect URIs, you can click “Create” to obtain the Client ID and Client Secret Key

7. Copy the Client ID and Client Secret Key and save them in a file.
## Configuring your Client ID, Secret Key and Authenticating Calendar with SuperAGI
Once the ClientID and Secret Key are obtained, you can configure and authorize Calendar to be used with SuperAGI by following these steps:
1. Add your Client ID and Client Secret Key on the toolkit page and click on “Update Changes”

2. Click on “Authenticate Tool” - which will now take you to the OAuth Flow.
Once the OAuth Authentication is complete, you can now start using SuperAGI Agents with Google Calendar!
================================================
FILE: superagi/tools/google_calendar/create_calendar_event.py
================================================
from typing import Any, Type
from pydantic import BaseModel, Field
from superagi.tools.base_tool import BaseTool
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
from superagi.helper.calendar_date import CalendarDate
class CreateEventCalendarInput(BaseModel):
    # Argument schema used as CreateEventCalendarTool.args_schema.
    # Date, time and location fields use the literal string 'None' (not Python None) to mean
    # "not provided", matching the 'None' defaults of CreateEventCalendarTool._execute.
    # NOTE(review): the event_name description contains the typo "craete".
    event_name: str = Field(..., description="Name of the event/meeting to be scheduled, if not given craete a name depending on description.")
    description: str = Field(..., description="Description of the event/meeting to be scheduled.")
    start_date: str = Field(..., description="Start date of the event to be scheduled in format 'yyyy-mm-dd', if no value is given keep the default value as 'None'.")
    start_time: str = Field(..., description="Start time of the event to be scheduled in format 'hh:mm:ss', if no value is given keep the default value as 'None'.")
    end_date: str = Field(..., description="End Date of the event to be scheduled in format 'yyyy-mm-dd', if no value is given keep the default value as 'None'.")
    end_time: str = Field(..., description="End Time of the event to be scheduled in format 'hh:mm:ss', if no value is given keep the default value as 'None'.")
    attendees: list = Field(..., description="List of attendees email ids to be invited for the event.")
    location: str = Field(..., description="Geographical location of the event. if no value is given keep the default value as 'None'")
class CreateEventCalendarTool(BaseTool):
    """
    Creates an event in the primary Google Calendar of the connected account.

    When no location is given, a Google Meet conference is requested for the event instead.
    """
    name: str = "Create Google Calendar Event"
    args_schema: Type[BaseModel] = CreateEventCalendarInput
    description: str = "Create an event for Google Calendar"

    def _execute(self, event_name: str, description: str, attendees: list, start_date: str = 'None',
                 start_time: str = 'None', end_date: str = 'None', end_time: str = 'None', location: str = 'None'):
        """
        Create the calendar event.

        Args:
            event_name: Summary/title of the event.
            description: Description of the event.
            attendees: Email addresses to invite.
            start_date, start_time, end_date, end_time: Event timing; the string 'None' means "not given".
            location: Event location; 'None' (or empty) requests a Google Meet link instead.

        Returns:
            A success message containing the event link, or a message asking to connect Google Calendar.
        """
        import uuid

        session = self.toolkit_config.session
        toolkit_id = self.toolkit_config.toolkit_id
        credentials = GoogleCalendarCreds(session).get_credentials(toolkit_id)
        if not credentials["success"]:
            return "Kindly connect to Google Calendar"
        service = credentials["service"]
        date_utc = CalendarDate().create_event_dates(service, start_date, start_time, end_date, end_time)
        event = {
            "summary": event_name,
            "description": description,
            "start": {
                "dateTime": date_utc["start_datetime_utc"],
                "timeZone": date_utc["timeZone"]
            },
            "end": {
                "dateTime": date_utc["end_datetime_utc"],
                "timeZone": date_utc["timeZone"]
            },
            "attendees": [{"email": attendee} for attendee in (attendees or [])]
        }
        if location and location != "None":
            event["location"] = location
        else:
            # The Calendar API ignores a conference createRequest whose requestId repeats an earlier one,
            # so a fresh id is generated for every event (a constant id broke Meet creation after the first).
            event["conferenceData"] = {
                "createRequest": {
                    "requestId": uuid.uuid4().hex,
                    "conferenceSolutionKey": {
                        "type": "hangoutsMeet"
                    },
                },
            }
        # conferenceDataVersion=1 is required for the API to honour conferenceData.
        event = service.events().insert(calendarId="primary", body=event, conferenceDataVersion=1).execute()
        output_str = f"Event {event_name} at {date_utc['start_datetime_utc']} created successfully, link for the event {event.get('htmlLink')}"
        return output_str
================================================
FILE: superagi/tools/google_calendar/delete_calendar_event.py
================================================
import base64
from typing import Any, Type
from pydantic import BaseModel, Field
from superagi.tools.base_tool import BaseTool
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
class DeleteCalendarEventInput(BaseModel):
    # Argument schema used as DeleteCalendarEventTool.args_schema.
    # The tool treats the literal string "None" as "no event id given".
    event_id: str = Field(..., description="The id of event to be deleted from Google Calendar. default value is None")
class DeleteCalendarEventTool(BaseTool):
    """Deletes an event from the primary Google Calendar of the connected account."""
    name: str = "Delete Google Calendar Event"
    args_schema: Type[BaseModel] = DeleteCalendarEventInput
    description: str = "Delete an event from Google Calendar"

    def _execute(self, event_id: str):
        """
        Delete the event identified by `event_id`.

        Args:
            event_id: The base64 "eid" value (as found in an event's htmlLink); the first space-separated
                token of its decoded text is used as the Calendar API event id. "None" means "not given".

        Returns:
            A status message for the agent.
        """
        credentials = GoogleCalendarCreds(self.toolkit_config.session).get_credentials(self.toolkit_config.toolkit_id)
        if not credentials["success"]:
            return "Kindly connect to Google Calendar"
        service = credentials["service"]
        if event_id == "None":
            return "Add Event ID to delete an event from Google Calendar"
        # Re-add any missing base64 padding.
        event_id += "=" * (-len(event_id) % 4)
        # urlsafe_b64decode accepts both base64 alphabets; plain b64decode silently drops '-' and '_'.
        decoded_id = base64.urlsafe_b64decode(event_id).decode("utf-8")
        eid = decoded_id.split(" ", 1)[0]
        service.events().delete(
            calendarId="primary",
            eventId=eid
        ).execute()
        return "Event Successfully deleted from your Google Calendar"
================================================
FILE: superagi/tools/google_calendar/event_details_calendar.py
================================================
import base64
from typing import Any, Type
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from superagi.tools.base_tool import BaseTool
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
class EventDetailsCalendarInput(BaseModel):
    # Argument schema used as EventDetailsCalendarTool.args_schema.
    # The tool treats the literal string "None" as "no event id given".
    event_id: str = Field(..., description="The id of event to be fetched from Google Calendar. if no value is given keep default value is None")
class EventDetailsCalendarTool(BaseTool):
    """Fetches a single event from the primary Google Calendar of the connected account."""
    name: str = "Fetch Google Calendar Event"
    args_schema: Type[BaseModel] = EventDetailsCalendarInput
    description: str = "Fetch an event from Google Calendar"

    def _execute(self, event_id: str):
        """
        Fetch and describe the event identified by `event_id`.

        Args:
            event_id: The base64 "eid" value (as found in an event's htmlLink); the first space-separated
                token of its decoded text is used as the Calendar API event id. "None" means "not given".

        Returns:
            A text summary (summary, start, end, attendees) of the event, or a status message.
        """
        credentials = GoogleCalendarCreds(self.toolkit_config.session).get_credentials(self.toolkit_config.toolkit_id)
        if not credentials["success"]:
            return "Kindly connect to Google Calendar"
        service = credentials["service"]
        if event_id == "None":
            return "Add Event ID to fetch details of an event from Google Calendar"
        # Re-add any missing base64 padding.
        event_id += "=" * (-len(event_id) % 4)
        # urlsafe_b64decode accepts both base64 alphabets; plain b64decode silently drops '-' and '_'.
        decoded_id = base64.urlsafe_b64decode(event_id).decode("utf-8")
        eid = decoded_id.split(" ", 1)[0]
        result = service.events().get(
            calendarId="primary",
            eventId=eid
        ).execute()
        # Missing fields previously left these names unbound (UnboundLocalError); default them instead.
        summary = result.get('summary', '')
        start = result.get('start') or {}
        end = result.get('end') or {}
        # Per the Calendar API Event resource, all-day events carry 'date' instead of 'dateTime'.
        start_date = start.get('dateTime', start.get('date', ''))
        end_date = end.get('dateTime', end.get('date', ''))
        attendees_str = ','.join(attendee['email'] for attendee in result.get('attendees', []))
        output_str = f"Event details for the event id '{event_id}' is - \nSummary : {summary}\nStart Date and Time : {start_date}\nEnd Date and Time : {end_date}\nAttendees : {attendees_str}"
        return output_str
================================================
FILE: superagi/tools/google_calendar/google_calendar_toolkit.py
================================================
from abc import ABC
from superagi.tools.base_tool import BaseToolkit, BaseTool, ToolConfiguration
from typing import Type, List
from superagi.tools.google_calendar.create_calendar_event import CreateEventCalendarTool
from superagi.tools.google_calendar.delete_calendar_event import DeleteCalendarEventTool
from superagi.tools.google_calendar.list_calendar_events import ListCalendarEventsTool
from superagi.tools.google_calendar.event_details_calendar import EventDetailsCalendarTool
from superagi.types.key_type import ToolConfigKeyType
class GoogleCalendarToolKit(BaseToolkit, ABC):
    # Bundles the create / delete / list / fetch Google Calendar tools behind one OAuth client.
    name: str = "Google Calendar Toolkit"
    description: str = "Google Calendar Tool kit contains all tools related to Google Calendar"

    def get_tools(self) -> List[BaseTool]:
        """Return a new instance of every Google Calendar tool, in a fixed order."""
        tool_classes = (
            CreateEventCalendarTool,
            DeleteCalendarEventTool,
            ListCalendarEventsTool,
            EventDetailsCalendarTool,
        )
        return [tool_cls() for tool_cls in tool_classes]

    def get_env_keys(self) -> List[ToolConfiguration]:
        """Return the OAuth client configuration keys; only the client secret is marked secret."""
        client_id = ToolConfiguration(key="GOOGLE_CLIENT_ID", key_type=ToolConfigKeyType.STRING,
                                      is_required=True, is_secret=False)
        client_secret = ToolConfiguration(key="GOOGLE_CLIENT_SECRET", key_type=ToolConfigKeyType.STRING,
                                          is_required=True, is_secret=True)
        return [client_id, client_secret]
================================================
FILE: superagi/tools/google_calendar/list_calendar_events.py
================================================
import os
import csv
from datetime import datetime
from typing import Type
from superagi.config.config import get_config
from pydantic import BaseModel, Field
from superagi.tools.base_tool import BaseTool
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
from superagi.helper.calendar_date import CalendarDate
from superagi.resource_manager.file_manager import FileManager
from superagi.helper.s3_helper import S3Helper
from urllib.parse import urlparse, parse_qs
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from superagi.lib.logger import logger
class ListCalendarEventsInput(BaseModel):
    # Argument schema used as ListCalendarEventsTool.args_schema.
    # Each field uses the literal string 'None' when no value is given; the values are handed to
    # CalendarDate().get_date_utc by the tool to build the query time range.
    start_time: str = Field(..., description="A string variable storing the start time to return events from the calendar in format 'HH:MM:SS'. if no value is given keep default value as 'None'")
    start_date: str = Field(..., description="A string variable storing the start date to return events from the calendar in format 'yyyy-mm-dd' in a string variable, if no value is given keep default value as 'None'.")
    end_date: str = Field(..., description="A string variable storing the end date to return events from the calendar in format 'yyyy-mm-dd' in a string variable, if no value is given keep default value as 'None'.")
    end_time: str = Field(..., description="A string variable storing the end time to return events from the calendar in format 'HH:MM:SS'. if no value is given keep default value as 'None'")
class ListCalendarEventsTool(BaseTool):
    """
    Lists the events of the primary Google Calendar in a date/time range and stores them as a CSV
    file through the resource manager.
    """
    name: str = "List Google Calendar Events"
    args_schema: Type[BaseModel] = ListCalendarEventsInput
    description: str = "Get the list of all the events from Google Calendar"
    agent_id: int = None
    resource_manager: FileManager = None

    def _execute(self, start_time: str = 'None', start_date: str = 'None', end_date: str = 'None', end_time: str = 'None'):
        """
        Fetch the events in the requested range and write them to a CSV file.

        Args:
            start_time, start_date, end_date, end_time: Range boundaries; the string 'None' means "not given".

        Returns:
            A message naming the CSV file, or a message explaining why nothing was stored.
        """
        service = self.get_google_calendar_service()
        if not service["success"]:
            return "Kindly connect to Google Calendar"
        date_utc = CalendarDate().get_date_utc(start_date, end_date, start_time, end_time, service["service"])
        event_results = self.get_event_results(service["service"], date_utc)
        # The list call returns a response dict even when nothing matches, so emptiness must be checked
        # on its "items" list; checking the dict alone made this branch unreachable.
        if not event_results or not event_results.get('items'):
            return "No events found for the given date and time range."
        csv_data = self.generate_csv_data(event_results)
        file_name = self.create_output_file()
        if file_name is not None:
            self.resource_manager.write_csv_file(file_name, csv_data)
            return f"List of Google Calendar Events month successfully stored in {file_name}."

    def get_google_calendar_service(self):
        """Return the credentials lookup result ({"success": ..., "service": ...}) for this toolkit."""
        return GoogleCalendarCreds(self.toolkit_config.session).get_credentials(self.toolkit_config.toolkit_id)

    def get_event_results(self, service, date_utc):
        """List single (expanded recurring) events of the primary calendar in the range, ordered by start."""
        # NOTE(review): only the first result page is read; nextPageToken is not followed.
        return (
            service.events().list(
                calendarId="primary",
                timeMin=date_utc['start_datetime_utc'],
                timeMax=date_utc['end_datetime_utc'],
                singleEvents=True,
                orderBy="startTime",
            ).execute()
        )

    def generate_csv_data(self, event_results):
        """Build CSV rows (header first) from the events in `event_results['items']`."""
        csv_data = [['Event ID', 'Event Name', 'Start Time', 'End Time', 'Attendees']]
        for item in event_results['items']:
            event_id, summary, start_date, end_date, attendees_str = self.parse_event_data(item)
            csv_data.append([event_id, summary, start_date, end_date, attendees_str])
        return csv_data

    def parse_event_data(self, item):
        """
        Extract (event id, summary, start, end, comma-joined attendee emails) from one event.

        The event id is the "eid" query parameter of the event's htmlLink, which is what the
        delete/fetch tools accept.
        """
        eid_url = item["htmlLink"]
        parsed_url = urlparse(eid_url)
        query_parameters = parse_qs(parsed_url.query)
        event_id = query_parameters.get('eid', [None])[0]
        summary = item.get('summary', '')
        # All-day events carry 'date' instead of 'dateTime'; fall back so their times are not blank.
        start_date = item['start'].get('dateTime', item['start'].get('date', ''))
        end_date = item['end'].get('dateTime', item['end'].get('date', ''))
        attendees = [attendee['email'] for attendee in item.get('attendees', [])]
        attendees_str = ','.join(attendees)
        return event_id, summary, start_date, end_date, attendees_str

    def create_output_file(self):
        """Return a timestamped CSV file name such as "Google_Calendar_<timestamp>.csv"."""
        # NOTE(review): datetime.now() is local time although the format ends in a literal "Z".
        file = datetime.now()
        file = file.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
        file_name = "Google_Calendar_" + file + ".csv"
        return file_name
================================================
FILE: superagi/tools/google_search/README.MD
================================================
# SuperAGI Google Search Tool
The SuperAGI Google Search Tool helps users perform a Google search and extract snippets and webpages.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
If you've put the correct Google API key and Custom Search Engine ID, you'll be able to use the Google Search Tool as well.
## Running SuperAGI Google Search Tool
You can simply ask your agent about the latest information on any topic, and your agent will be able to browse the internet to get that information for you.
================================================
FILE: superagi/tools/google_search/__init__.py
================================================
================================================
FILE: superagi/tools/google_search/google_search.py
================================================
import json
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.google_search import GoogleSearchWrap
from superagi.helper.token_counter import TokenCounter
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.tools.base_tool import BaseTool
class GoogleSearchSchema(BaseModel):
    # Argument schema used as GoogleSearchTool.args_schema.
    query: str = Field(
        ...,
        description="The search query for Google search.",
    )
class GoogleSearchTool(BaseTool):
    """
    Google Search tool: runs a Google Custom Search, summarises the results with the LLM and
    appends up to three result links.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    llm: Optional[BaseLlm] = None
    name = "GoogleSearch"
    agent_id: int = None
    agent_execution_id: int = None
    description = (
        "A tool for performing a Google search and extracting snippets and webpages. "
        "Input should be a search query."
    )
    args_schema: Type[GoogleSearchSchema] = GoogleSearchSchema

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, query: str) -> str:
        """
        Execute the Google search tool.

        Args:
            query : The query to search for.

        Returns:
            Search result summary along with related links
        """
        api_key = self.get_tool_config("GOOGLE_API_KEY")
        search_engine_id = self.get_tool_config("SEARCH_ENGINE_ID")
        num_results = 10
        num_pages = 1
        num_extracts = 3
        google_search = GoogleSearchWrap(api_key, search_engine_id, num_results, num_pages, num_extracts)
        snippets, webpages, links = google_search.get_result(query)
        results = []
        # zip stops at the shortest list instead of raising IndexError when the lists differ in length.
        for snippet, webpage, link in zip(snippets, webpages, links):
            results.append({"title": snippet, "body": webpage, "links": link})
            # Stop collecting once the serialized results exceed ~3000 tokens to bound the summary prompt.
            if TokenCounter.count_text_tokens(json.dumps(results)) > 3000:
                break
        summary = self.summarise_result(query, results)
        result_links = [result["links"] for result in results if len(result["links"]) > 0]
        if len(result_links) > 0:
            return summary + "\n\nLinks:\n" + "\n".join("- " + link for link in result_links[:3])
        return summary

    def summarise_result(self, query, snippets):
        """
        Summarise the result of a Google search.

        Args:
            query : The query to search for.
            snippets (list): The collected search results (title/body/links dicts).

        Returns:
            A summary of the search result.
        """
        summarize_prompt ="""Summarize the following text `{snippets}`
            Write a concise or as descriptive as necessary and attempt to
            answer the query: `{query}` as best as possible. Use markdown formatting for
            longer responses."""
        summarize_prompt = summarize_prompt.replace("{snippets}", str(snippets))
        summarize_prompt = summarize_prompt.replace("{query}", query)
        messages = [{"role": "system", "content": summarize_prompt}]
        result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
        if 'error' in result and result['message'] is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id,
                                              result['message'])
        return result["content"]
================================================
FILE: superagi/tools/google_search/google_search_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.google_search.google_search import GoogleSearchTool
from superagi.models.tool_config import ToolConfig
from superagi.types.key_type import ToolConfigKeyType
class GoogleSearchToolkit(BaseToolkit, ABC):
    # Toolkit exposing GoogleSearchTool; it needs a Google API key and a custom search engine id.
    name: str = "Google Search Toolkit"
    description: str = "Toolkit containing tools for performing Google search and extracting snippets and webpages"

    def get_tools(self) -> List[BaseTool]:
        """Return the single search tool provided by this toolkit."""
        search_tool = GoogleSearchTool()
        return [search_tool]

    def get_env_keys(self) -> List[ToolConfiguration]:
        """Return the required, secret configuration keys in a fixed order."""
        return [
            ToolConfiguration(key=config_key, key_type=ToolConfigKeyType.STRING, is_required=True, is_secret=True)
            for config_key in ("GOOGLE_API_KEY", "SEARCH_ENGINE_ID")
        ]
================================================
FILE: superagi/tools/google_serp_search/README.md
================================================
# SuperAGI Google SERP Search Toolkit
The SuperAGI Google SERP Search Toolkit helps users perform a Google search through the serper.dev API and extract snippets and webpages.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
### 🔧 **Add Google Serp Search API Key in SuperAGI Dashboard**
1. Register an account at [https://serper.dev/](https://serper.dev/) with your Email ID.
2. A private API key will be generated for your account. Copy it and save it in a separate text file.

3. Open up the Google SERP Toolkit page in SuperAGI's Dashboard and paste your Private API Key.
## Running SuperAGI Google Search Serp Tool
You can simply ask your agent about the latest information regarding anything and your agent will be able to browse the internet to get that information for you.
================================================
FILE: superagi/tools/google_serp_search/__init__.py
================================================
================================================
FILE: superagi/tools/google_serp_search/google_serp_search.py
================================================
from typing import Type, Optional, Any
from pydantic import BaseModel, Field
import aiohttp
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.google_serp import GoogleSerpApiWrap
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.tools.base_tool import BaseTool
import os
import json
class GoogleSerpSchema(BaseModel):
    # Argument schema used as GoogleSerpTool.args_schema.
    query: str = Field(
        ...,
        description="The search query for Google SERP.",
    )
# Google search using serper.dev. Use serper.dev api keys.
class GoogleSerpTool(BaseTool):
    """
    Google SERP search tool (serper.dev): runs the search, summarises the snippets with the LLM and
    appends up to three result links.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    llm: Optional[BaseLlm] = None
    name = "GoogleSerp"
    agent_id: int = None
    agent_execution_id: int = None
    description = (
        "A tool for performing a Google SERP search and extracting snippets and webpages. "
        "Input should be a search query."
    )
    args_schema: Type[GoogleSerpSchema] = GoogleSerpSchema

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, query: str) -> str:
        """
        Execute the Google SERP search tool.

        Args:
            query : The query to search for.

        Returns:
            Search result summary along with related links
        """
        api_key = self.get_tool_config("SERP_API_KEY")
        serp_api = GoogleSerpApiWrap(api_key)
        response = serp_api.search_run(query)
        summary = self.summarise_result(query, response["snippets"])
        if response["links"]:
            return summary + "\n\nLinks:\n" + "\n".join("- " + link for link in response["links"][:3])
        return summary

    def summarise_result(self, query, snippets):
        """
        Summarise the SERP snippets with the LLM so that they answer `query`.

        Args:
            query : The query that was searched.
            snippets : The snippets returned by the search.

        Returns:
            The LLM's summary text.
        """
        summarize_prompt = """Summarize the following text `{snippets}`
            Write a concise or as descriptive as necessary and attempt to
            answer the query: `{query}` as best as possible. Use markdown formatting for
            longer responses."""
        summarize_prompt = summarize_prompt.replace("{snippets}", str(snippets))
        summarize_prompt = summarize_prompt.replace("{query}", query)
        messages = [{"role": "system", "content": summarize_prompt}]
        result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
        if 'error' in result and result['message'] is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id,
                                              result['message'])
        return result["content"]
================================================
FILE: superagi/tools/google_serp_search/google_serp_search_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.google_serp_search.google_serp_search import GoogleSerpTool
from superagi.models.tool_config import ToolConfig
from superagi.types.key_type import ToolConfigKeyType
class GoogleSerpToolkit(BaseToolkit, ABC):
    # Toolkit exposing GoogleSerpTool; it needs a serper.dev API key.
    name: str = "Google SERP Toolkit"
    description: str = "Toolkit containing tools for performing Google SERP search and extracting snippets and webpages"

    def get_tools(self) -> List[BaseTool]:
        """Return the single SERP search tool provided by this toolkit."""
        serp_tool = GoogleSerpTool()
        return [serp_tool]

    def get_env_keys(self) -> List[ToolConfiguration]:
        """Return the required, secret serper.dev API key configuration."""
        serp_api_key = ToolConfiguration(key="SERP_API_KEY", key_type=ToolConfigKeyType.STRING,
                                         is_required=True, is_secret=True)
        return [serp_api_key]
================================================
FILE: superagi/tools/image_generation/README.MD
================================================
# SuperAGI Image Generation Tool
The SuperAGI Image Generation Tool helps you generate an image with a prompt using DALL-E.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
If you've put the correct OpenAI key during the installation, you'd be able to use the Image Generation tool as well.
## Running SuperAGI Image Generation Tool
You can simply put one of the goals of your agent to create an image and the agent will pick up the Image Generation Tool and will place it in the Output folder of the Resource Manager, from where you'll be able to download it.
================================================
FILE: superagi/tools/image_generation/README.STABLE_DIFFUSION.md
================================================
## SuperAGI Stable Diffusion Toolkit
Introducing Stable Diffusion Integration with SuperAGI
You can now use SuperAGI to summon Stable Diffusion to create true-to-life images which opens up a whole new range of possibilities.
# ⚙️ Installation
## 🛠️ Setting up SuperAGI
Set up SuperAGI by following the instruction given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
## 🔧Configuring API from DreamStudio
You can now get your API Key from Dream Studio to use Stable Diffusion by following the instructions below:
1. Create an Account/Login with [DreamStudio.ai](http://DreamStudio.ai)

2. Click on the Profile Icon at the top right which will take you to the settings page. Once you have reached the settings page, you can now get your API keys

3. Copy the API Key and save it in a separate file
## 🛠️Configuring Stable Diffusion with SuperAGI
You can configure SuperAGI with Stable Diffusion using the following steps:
1. Navigate to the **Toolkit** page in SuperAGI’s Dashboard and select **Image Generation Toolkit**

2. Once you’ve clicked Image Generation Toolkit, it will open a page asking you for the API Key and the Model Engine. You can enter the generated API key from Dream Studio here

3. To use a specific Stable Diffusion model, enter one of the following engine IDs as the Model Engine:
- 'stable-diffusion-v1'
- 'stable-diffusion-v1-5'
- 'stable-diffusion-512-v2-0'
- 'stable-diffusion-768-v2-0'
- 'stable-diffusion-512-v2-1'
- 'stable-diffusion-768-v2-1'
- 'stable-diffusion-xl-beta-v2-2-2'
You have now successfully configured Stable Diffusion with SuperAGI!
================================================
FILE: superagi/tools/image_generation/__init__.py
================================================
================================================
FILE: superagi/tools/image_generation/dalle_image_gen.py
================================================
from typing import Type, Optional
import requests
from pydantic import BaseModel, Field
from superagi.image_llms.openai_dalle import OpenAiDalle
from superagi.llms.base_llm import BaseLlm
from superagi.resource_manager.file_manager import FileManager
from superagi.models.toolkit import Toolkit
from superagi.models.configuration import Configuration
from superagi.tools.base_tool import BaseTool
class DalleImageGenInput(BaseModel):
    # Free-text prompt forwarded to DALL-E.
    prompt: str = Field(
        ...,
        description="Prompt for Image Generation to be used by Dalle.",
    )
    # Requested edge length in pixels.
    size: int = Field(
        ...,
        description="Size of the image to be Generated. default size is 512",
    )
    # How many images to request.
    num: int = Field(
        ...,
        description="Number of Images to be generated. default num is 2",
    )
    # Output file names, one per image (bare names, no directories).
    image_names: list = Field(
        ...,
        description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.",
    )
class DalleImageGenTool(BaseTool):
    """
    Dalle Image Generation tool

    Attributes:
        name : Name of the tool
        description : The description
        args_schema : The args schema
        agent_id : The agent id
        agent_execution_id : The agent execution id
        resource_manager : Manages the file resources
    """
    name: str = "DalleImageGeneration"
    args_schema: Type[BaseModel] = DalleImageGenInput
    description: str = "Generate Images using Dalle"
    agent_id: int = None
    agent_execution_id: int = None
    resource_manager: Optional[FileManager] = None

    def _execute(self, prompt: str, image_names: list, size: int = 512, num: int = 2):
        """
        Execute the Dalle Image Generation tool.

        Args:
            prompt : The prompt for image generation.
            image_names (list): File names for the generated images, one per image.
            size : The size of the image to be generated. Snapped to the nearest of
                256, 512 or 1024.
            num : The number of images to be generated.

        Returns:
            Image generated successfully message if image is generated or error message.

        Raises:
            requests.HTTPError: if downloading a generated image fails.
        """
        allowed_sizes = [256, 512, 1024]
        if size not in allowed_sizes:
            size = min(allowed_sizes, key=lambda x: abs(x - size))

        api_key = self.get_tool_config("OPENAI_API_KEY")
        # Treat an empty key the same as a missing one instead of calling the API with it.
        if not api_key:
            return "Enter your OpenAi api key in the configuration"

        response = OpenAiDalle(api_key=api_key, number_of_results=num).generate_image(prompt, size)
        # The image entries are read from the response object's `_previous['data']`;
        # each entry carries a download `url`.
        images = response.__dict__['_previous']['data']

        # Never index past the images actually returned or the names supplied.
        count = min(num, len(images), len(image_names))
        for i in range(count):
            download = requests.get(images[i]['url'], timeout=60)
            # Don't save an HTTP error body to disk as if it were an image.
            download.raise_for_status()
            self.resource_manager.write_binary_file(image_names[i], download.content)
        return "Images downloaded successfully"
================================================
FILE: superagi/tools/image_generation/image_generation_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.image_generation.dalle_image_gen import DalleImageGenTool
from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool
from superagi.types.key_type import ToolConfigKeyType
class ImageGenToolkit(BaseToolkit, ABC):
    name: str = "Image Generation Toolkit"
    description: str = "Toolkit containing a tool for generating images"

    def get_tools(self) -> List[BaseTool]:
        # DALL-E first, then Stable Diffusion.
        tools = [DalleImageGenTool(), StableDiffusionImageGenTool()]
        return tools

    def get_env_keys(self) -> List[ToolConfiguration]:
        # Every key is optional: STABILITY_API_KEY/ENGINE_ID are read by the Stable
        # Diffusion tool, OPENAI_API_KEY by the DALL-E tool. Only ENGINE_ID is not secret.
        settings = (
            ("STABILITY_API_KEY", True),
            ("ENGINE_ID", False),
            ("OPENAI_API_KEY", True),
        )
        env_keys = []
        for config_key, secret in settings:
            env_keys.append(
                ToolConfiguration(key=config_key, key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=secret)
            )
        return env_keys
================================================
FILE: superagi/tools/image_generation/stable_diffusion_image_gen.py
================================================
import base64
from io import BytesIO
from typing import Type, Optional
import requests
from PIL import Image
from pydantic import BaseModel, Field
from superagi.helper.resource_helper import ResourceHelper
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent import Agent
class StableDiffusionImageGenInput(BaseModel):
    # Text prompt sent to the Stability API.
    prompt: str = Field(
        ...,
        description="Prompt for Image Generation to be used by Stable Diffusion. The prompt should be as descriptive as possible and mention all the details of the image to be generated",
    )
    # Requested dimensions in pixels.
    height: int = Field(
        ...,
        description="Height of the image to be Generated. default height is 512",
    )
    width: int = Field(
        ...,
        description="Width of the image to be Generated. default width is 512",
    )
    # Number of images (API "samples") to request.
    num: int = Field(
        ...,
        description="Number of Images to be generated. default num is 2",
    )
    # Number of diffusion steps.
    steps: int = Field(
        ...,
        description="Number of diffusion steps to run. default steps are 50",
    )
    # Output file names, one per image (bare names, no directories).
    image_names: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.")
class StableDiffusionImageGenTool(BaseTool):
    """
    Stable diffusion Image Generation tool

    Attributes:
        name : Name of the tool
        description : The description
        args_schema : The args schema
        agent_id : The agent id
        agent_execution_id : The agent execution id
        resource_manager : Manages the file resources
    """
    name: str = "Stable Diffusion Image Generation"
    args_schema: Type[BaseModel] = StableDiffusionImageGenInput
    description: str = "Generate Images using Stable Diffusion"
    agent_id: int = None
    agent_execution_id: int = None
    resource_manager: Optional[FileManager] = None

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, prompt: str, image_names: list, width: int = 512, height: int = 512, num: int = 2,
                 steps: int = 50):
        """
        Generate images with the Stability text-to-image API and save them as agent resources.

        Args:
            prompt : Text prompt for the image.
            image_names (list): File names to save the images under, one per image.
            width : Requested width in pixels.
            height : Requested height in pixels.
            num : Number of images to request.
            steps : Number of diffusion steps.

        Returns:
            A success message, or an error message when the API key or engine ID is
            missing or the API answers with a non-200 status.
        """
        api_key = self.get_tool_config("STABILITY_API_KEY")
        if api_key is None:
            return "Error: Missing Stability API key."
        engine_id = self.get_tool_config("ENGINE_ID")
        if not engine_id:
            # Without an engine the request URL cannot be built.
            return "Error: Missing Stability engine ID."

        response = self.call_stable_diffusion(api_key, width, height, num, prompt, steps, engine_id=engine_id)
        if response.status_code != 200:
            return f"Non-200 response: {str(response.text)}"

        artifacts = response.json()['artifacts']
        # Save only as many images as were both returned and named.
        count = min(num, len(artifacts), len(image_names))
        for i in range(count):
            img_data = base64.b64decode(artifacts[i]['base64'])
            final_img = Image.open(BytesIO(img_data))
            # Re-encode in the decoded image's own format before writing.
            img_byte_arr = BytesIO()
            final_img.save(img_byte_arr, format=final_img.format)
            self.resource_manager.write_binary_file(image_names[i], img_byte_arr.getvalue())
        return "Images downloaded and saved successfully!!"

    def call_stable_diffusion(self, api_key, width, height, num, prompt, steps, engine_id=None):
        """
        POST a text-to-image request to the Stability API.

        Args:
            api_key : Stability API key (sent as a Bearer token).
            width : Requested width in pixels.
            height : Requested height in pixels.
            num : Number of images ("samples") to request.
            prompt : Text prompt.
            steps : Number of diffusion steps.
            engine_id : Stability engine ID; read from the ENGINE_ID tool config when omitted.

        Returns:
            The `requests` response object.
        """
        if engine_id is None:
            engine_id = self.get_tool_config("ENGINE_ID")
        # Engines whose ID contains "768" get each dimension raised to at least 768.
        if engine_id and "768" in engine_id:
            height = max(height, 768)
            width = max(width, 768)
        response = requests.post(
            f"https://api.stability.ai/v1/generation/{engine_id}/text-to-image",
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}"
            },
            json={
                "text_prompts": [{"text": prompt}],
                "height": height,
                "width": width,
                "samples": num,
                "steps": steps,
            },
        )
        return response
================================================
FILE: superagi/tools/instagram_tool/README.MD
================================================
# SuperAGI Instagram Tool
The SuperAGI Instagram Tool works with the Stable Diffusion tool: it generates an image and a caption based on the goals defined by the user and posts them on the user's Instagram business account. Currently, it only works on the web app.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
If you've put the correct Google API key and Custom Search Engine ID, you'll be able to use the Google Search Tool as well.
### 🔧 **Instagram tool requirements**
Since the tool uses the official Instagram Graph API to post media on user accounts, there are a few requirements.
You will need access to the following:
1. An Instagram Business Account or Instagram Creator Account
2. A Facebook Page connected to that account
3. A Facebook Developer account that can perform Tasks on that Page
4. A registered Facebook App with Basic settings configured
Once everything is set up, add the following to the corresponding toolkits:
- the Meta user access token (generated from your Facebook developer account);
- the Facebook page ID (on the Facebook page connected to the Instagram account, under 'Page transparency' in the 'About' section);
- the Stability API key.
Follow the steps given in the link to set up meta requirements: (https://developers.facebook.com/docs/instagram-api/getting-started)
Follow the link to generate stability API key: (https://dreamstudio.com/api/)
### 🔧 **Configuring in SuperAGI Dashboard:**
- Add your Meta user access token and Facebook page ID to the Instagram Toolkit page, and your Stability API key to the Image Generation Toolkit page.
## Running SuperAGI Instagram Tool
Once everything has been set up just run/schedule an agent with the goal explaining the media to be published and add instagram tool (which will automatically add stable diffusion tool)
## Warning
It is advised to run the instagram tool in restricted mode since it allows you to validate the photos generated. You can schedule agent runs (recurring runs are supported as well). Also, only one photo will be posted to your account in a run. To post multiple photos use recurring runs.
================================================
FILE: superagi/tools/instagram_tool/__init__.py
================================================
================================================
FILE: superagi/tools/instagram_tool/instagram.py
================================================
import json
import urllib
import boto3
import os
from superagi.config.config import get_config
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.resource_helper import ResourceHelper
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.helper.token_counter import TokenCounter
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.tools.base_tool import BaseTool
import os
import requests
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
import random
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.helper.s3_helper import S3Helper
from superagi.types.storage_types import StorageType
class InstagramSchema(BaseModel):
    # Free-text description of the photo; used to generate the caption.
    photo_description: str = Field(..., description="description of the photo")
    # Name of a single image file among the agent's resources.
    filename: str = Field(
        ...,
        description="Name of the file to be posted. Only one file can be posted at a time.",
    )
class InstagramTool(BaseTool):
    """
    Instagram tool

    Posts one image from the agent's resources to the Instagram business account
    linked to the configured Facebook page, with an LLM-generated caption.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
        llm : LLM used to write the caption.
        agent_id : The agent id.
        agent_execution_id : The agent execution id.
    """
    llm: Optional[BaseLlm] = None
    name = "Instagram tool"
    description = (
        "A tool for posting an AI generated photo on Instagram"
    )
    args_schema: Type[InstagramSchema] = InstagramSchema
    tool_response_manager: Optional[ToolResponseQueryManager] = None
    agent_id: int = None
    agent_execution_id: int = None

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, photo_description: str, filename: str) -> str:
        """
        Execute the Instagram tool.

        Args:
            photo_description : description of the photo to be posted
            filename : name of the image file (among the agent's resources) to post

        Returns:
            Image posted successfully message if image has been posted on instagram or error message.
        """
        session = self.toolkit_config.session
        meta_user_access_token = self.get_tool_config("META_USER_ACCESS_TOKEN")
        facebook_page_id = self.get_tool_config("FACEBOOK_PAGE_ID")
        if meta_user_access_token is None:
            return "Error: Missing meta user access token."
        if facebook_page_id is None:
            return "Error: Missing facebook page id."

        root_api_url = "https://graph.facebook.com/v17.0/"

        # GET the Instagram business account linked to the Facebook page.
        response = self.get_req_insta_id(root_api_url, facebook_page_id, meta_user_access_token)
        if response.status_code != 200:
            return f"Non-200 response: {str(response.text)}"
        try:
            insta_business_account_id = response.json()["instagram_business_account"]["id"]
        except (KeyError, TypeError):
            return "Error: No Instagram business account is linked to the Facebook page."

        file_path = self.get_file_path(session, filename, self.agent_id, self.agent_execution_id)
        # Uploads the image and generates the caption. The caption is generated only
        # here, so the LLM is called once per post.
        image_url, encoded_caption = self.get_img_url_and_encoded_caption(photo_description, file_path, filename)

        # POST to create a media container holding the image and caption.
        response = self.post_media_container_id(root_api_url, insta_business_account_id, image_url,
                                                encoded_caption, meta_user_access_token)
        if response.status_code != 200:
            return f"Non-200 response: {str(response.text)}"
        container_id = response.json()["id"]

        # POST to publish the media container on the Instagram account.
        response = self.post_media(root_api_url, insta_business_account_id, container_id, meta_user_access_token)
        if response.status_code != 200:
            return f"Non-200 response: {str(response.text)}"
        return "Photo posted successfully!"

    def create_caption(self, photo_description: str) -> str:
        """
        Generate an Instagram caption for the photo with the LLM and URL-encode it.

        Args:
            photo_description : Description of the photo to be posted

        Returns:
            The URL-encoded caption, ready to be placed in a query string.
        """
        # Imported explicitly: `import urllib` alone does not guarantee that the
        # `urllib.parse` submodule is loaded.
        from urllib.parse import quote

        caption_prompt = """Generate an instagram post caption for the following text `{photo_description}`
Attempt to make it as relevant as possible to the description and should be different and unique everytime. Add relevant emojis and hashtags."""
        caption_prompt = caption_prompt.replace("{photo_description}", str(photo_description))
        messages = [{"role": "system", "content": caption_prompt}]
        result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
        if 'error' in result and result['message'] is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
        caption = result["content"]
        return quote(caption)

    def get_file_path(self, session, file_name, agent_id, agent_execution_id):
        """
        Gets the read path of an image file among the agent's resources.

        Args:
            session: DB session used to load the agent and the agent execution.
            file_name: Name of the media file to be posted.
            agent_id: The agent id.
            agent_execution_id: The agent execution id.

        Returns:
            The path of the image file
        """
        final_path = ResourceHelper().get_agent_read_resource_path(file_name,
                                                                   agent=Agent.get_agent_from_id(session, agent_id),
                                                                   agent_execution=AgentExecution.get_agent_execution_from_id(
                                                                       session, agent_execution_id))
        return final_path

    def get_img_public_url(self, filename, content):
        """
        Uploads the image bytes to the INSTAGRAM_TOOL_BUCKET_NAME S3 bucket and returns its URL.

        The bucket is presumably configured for public reads, since Instagram must
        fetch the image from this URL. NOTE(review): confirm bucket policy.

        Args:
            filename: Name of the image file; used in the object key.
            content: Image file bytes.

        Returns:
            The https URL of the uploaded object.
        """
        bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME")
        object_key = f"instagram_upload_images/{filename}"
        S3Helper(bucket_name).upload_file_content(content, object_key)
        image_url = f"https://{bucket_name}.s3.amazonaws.com/{object_key}"
        return image_url

    def get_img_url_and_encoded_caption(self, photo_description, file_path, filename):
        """
        Uploads the image to the public bucket and generates the URL-encoded caption.

        Returns:
            Tuple of (image_url, encoded_caption).
        """
        # Read the image bytes from local storage or S3, depending on STORAGE_TYPE.
        content = self._get_image_content(file_path)
        # Store the image in the public bucket and get its URL.
        image_url = self.get_img_public_url(filename, content)
        # Caption with emojis and hashtags, URL-encoded for the Graph API query string.
        encoded_caption = self.create_caption(photo_description)
        return image_url, encoded_caption

    def get_req_insta_id(self, root_api_url, facebook_page_id, meta_user_access_token):
        """GET the page's `instagram_business_account` field from the Graph API."""
        url_to_get_acc_id = f"{root_api_url}{facebook_page_id}?fields=instagram_business_account&access_token={meta_user_access_token}"
        response = requests.get(
            url_to_get_acc_id
        )
        return response

    def post_media_container_id(self, root_api_url, insta_business_account_id, image_url, encoded_caption, meta_user_access_token):
        """POST to the account's `/media` endpoint to create a media container."""
        url_to_create_media_container = f"{root_api_url}{insta_business_account_id}/media?image_url={image_url}&caption={encoded_caption}&access_token={meta_user_access_token}"
        response = requests.post(
            url_to_create_media_container
        )
        return response

    def post_media(self, root_api_url, insta_business_account_id, container_ID, meta_user_access_token):
        """POST to the account's `/media_publish` endpoint to publish a media container."""
        url_to_post_media_container = f"{root_api_url}{insta_business_account_id}/media_publish?creation_id={container_ID}&access_token={meta_user_access_token}"
        response = requests.post(
            url_to_post_media_container
        )
        return response

    def _get_image_content(self, file_path):
        """Read the image bytes from S3 or the local file system, based on STORAGE_TYPE."""
        if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
            attachment_data = S3Helper().read_binary_from_s3(file_path)
        else:
            with open(file_path, "rb") as file:
                attachment_data = file.read()
        return attachment_data
================================================
FILE: superagi/tools/instagram_tool/instagram_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.instagram_tool.instagram import InstagramTool
from superagi.types.key_type import ToolConfigKeyType
class InstagramToolkit(BaseToolkit, ABC):
    name: str = "Instagram Toolkit"
    description: str = "Toolkit containing tools for posting AI generated photo on Instagram. Posts only one photo in a run "

    def get_tools(self) -> List[BaseTool]:
        tools = [InstagramTool()]
        return tools

    def get_env_keys(self) -> List[ToolConfiguration]:
        # Both settings are required; only the access token is secret.
        settings = (
            ("META_USER_ACCESS_TOKEN", True),
            ("FACEBOOK_PAGE_ID", False),
        )
        return [
            ToolConfiguration(key=config_key, key_type=ToolConfigKeyType.STRING, is_required=True, is_secret=secret)
            for config_key, secret in settings
        ]
================================================
FILE: superagi/tools/jira/README.MD
================================================
# SuperAGI Jira Tool
The SuperAGI Jira Tool lets users create, edit and search issues while providing a foundation for other great use cases.
## 💡 Features
1. **Create Issue:** SuperAGI's JIRA tool lets you seamlessly create new tasks in your project by defining the task's details such as its summary, description, type, and priority.
2. **Edit Issue:** Modify existing tasks quickly with SuperAGI's JIRA tool, which allows you to change any task details like summary, description, type, and priority.
3. **Search Issues:** Use the powerful 'Search Issues' feature to find specific tasks within your projects by defining your search criteria in terms of project, assignee, or keywords in the task summary.
4. **Get Projects:** Discover and access all your projects with ease using the 'Get Projects' feature, providing a bird's eye view of your workload and streamlining project-based searches.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
### 🔧 **Add Jira configuration settings in SuperAGI Dashboard**
You will need the following configuration settings:
1. _JIRA API TOKEN:_
- Login into your Jira Account. Go to "Manage Account".
- Go to Security and click on "Create and Manage API Tokens".
- Click on "Create API Token" and choose an appropriate label name.
- Copy the API Token and save it in a text file.
2. _JIRA INSTANCE URL:_
- Your instance URL is the base address of your Jira site. It should look something like "https://mycompany.atlassian.net/".
3. _JIRA USERNAME:_
- Your Jira UserName is the Email Address with which you signed up in Jira.
4. _CONFIGURING JIRA IN SUPERAGI DASHBOARD:_
- Open the Jira Toolkit page in SuperAGI, add your Jira API token, instance URL, and Jira username, and click "Update Changes".
## Running SuperAGI Jira Tool
1. **Create an Issue:** The SuperAGI JIRA Create Issue tool allows you to create issues in your project. By default, it creates a task with predefined settings. To create a task with different details, modify the relevant fields in the create_issue.py script.
2. **Edit an Issue:** To edit a particular issue, specify the issue ID in your goal. The modifications can be made by changing the relevant fields in the edit_issue.py script.
3. **Search for Issues:** You can simply search for a particular issue in your agent's goals and your agent performs a search based on the JIRA Query Language (JQL) query you define. Modify the JQL query according to your requirements in the search_issues.py script.
4. **Fetch Project Details:** Use the 'Get Projects' feature to retrieve a list of your accessible projects. The get_projects.py script can be modified to adjust the parameters of this operation.
================================================
FILE: superagi/tools/jira/__init__.py
================================================
================================================
FILE: superagi/tools/jira/create_issue.py
================================================
from typing import Type
from pydantic import BaseModel, Field
from superagi.tools.jira.tool import JiraTool, JiraIssueSchema
class CreateIssueSchema(BaseModel):
    # Jira field mapping passed straight through to `create_issue(fields=...)`.
    fields: dict = Field(..., description='Dictionary of fields to create the Jira issue with. Format: {{"summary": "test issue", "project": "project_id", "description": "test description", "issuetype": {{"name": "Task"}}, "priority": {{"name": "Low"}}}}')
class CreateIssueTool(JiraTool):
    """
    Jira tool that creates a new issue from a field mapping.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name = "CreateJiraIssue"
    description = "Create a new Jira issue."
    args_schema: Type[CreateIssueSchema] = CreateIssueSchema

    def _execute(self, fields: dict):
        """
        Create a Jira issue.

        Args:
            fields (dict): Jira fields for the new issue, e.g. {"summary": "test issue",
                "project": "project_id", "description": "test description",
                "issuetype": {"name": "Task"}, "priority": {"name": "Low"}}

        Returns:
            A success message containing the key of the created issue.
        """
        created = self.build_jira_instance().create_issue(fields=fields)
        return f"Issue '{created.key}' created successfully!"
================================================
FILE: superagi/tools/jira/edit_issue.py
================================================
from typing import Type
from pydantic import Field, BaseModel
from superagi.tools.jira.tool import JiraTool, JiraIssueSchema
class EditIssueSchema(BaseModel):
    # Key (or id) of the issue to edit.
    key: str = Field(
        ...,
        description="Issue key or id in Jira",
    )
    # Field mapping applied with `issue.update(fields=...)`. The description is
    # shown to the LLM, so it must say "update" rather than "create".
    fields: dict = Field(
        ...,
        description='Dictionary of fields to update on the Jira issue. Format: {{"summary": "test issue", "project": "project_id", "description": "test description", "issuetype": {{"name": "Task"}}, "priority": {{"name": "Low"}}}}',
    )
class EditIssueTool(JiraTool):
    """
    Edit Jira Issue tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name = "EditJiraIssue"
    description = "Edit a Jira issue."
    args_schema: Type[EditIssueSchema] = EditIssueSchema

    def _execute(self, key: str, fields: dict):
        """
        Execute the edit issue tool.

        Args:
            key : Issue key or id in Jira
            fields (dict): Jira fields to update, e.g. {"summary": "test issue",
                "description": "test description", "priority": {"name": "Low"}}

        Returns:
            A success message mentioning the key of the edited issue, or "Issue not found!".
        """
        jira = self.build_jira_instance()
        issues = jira.search_issues(f"key={key}")
        if not issues:
            return "Issue not found!"
        issue = issues[0]
        issue.update(fields=fields)
        # The issue was edited, not created; report that accurately to the agent.
        return f"Issue '{issue.key}' updated successfully!"
================================================
FILE: superagi/tools/jira/get_projects.py
================================================
from typing import Type, List
from pydantic import BaseModel, Field
from superagi.tools.jira.tool import JiraIssueSchema, JiraTool
class GetProjectsSchema(BaseModel):
    # The get-projects tool takes no arguments; this schema is intentionally empty.
    ...
class GetProjectsTool(JiraTool):
    """
    Jira tool that lists the projects visible to the configured user.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name = "GetJiraProjects"
    description = "This tool is a wrapper around atlassian-python-api's Jira project API. Useful in fetching all the projects accessible to the user, discovering the total count of projects, or utilizing it as an interim step during project-based searches."
    args_schema: Type[GetProjectsSchema] = GetProjectsSchema

    def parse_projects(self, projects: List[dict]) -> List[dict]:
        # Keep only id, key and name from each project object (read as attributes).
        return [
            {"id": project.id, "key": project.key, "name": project.name}
            for project in projects
        ]

    def _execute(self) -> str:
        """
        List the Jira projects.

        Returns:
            "Found <n> projects:" followed by the list of {id, key, name} dicts.
        """
        summaries = self.parse_projects(self.build_jira_instance().projects())
        return f"Found {len(summaries)} projects:\n{summaries}"
================================================
FILE: superagi/tools/jira/jira_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.jira.create_issue import CreateIssueTool
from superagi.tools.jira.edit_issue import EditIssueTool
from superagi.tools.jira.get_projects import GetProjectsTool
from superagi.tools.jira.search_issues import SearchJiraTool
from superagi.types.key_type import ToolConfigKeyType
from superagi.models.tool_config import ToolConfig
class JiraToolkit(BaseToolkit, ABC):
    name: str = "Jira Toolkit"
    description: str = "Toolkit containing tools for Jira integration"

    def get_tools(self) -> List[BaseTool]:
        tool_classes = (CreateIssueTool, EditIssueTool, GetProjectsTool, SearchJiraTool)
        return [tool_class() for tool_class in tool_classes]

    def get_env_keys(self) -> List[ToolConfiguration]:
        # All three settings are required; only the API token is secret.
        settings = (
            ("JIRA_INSTANCE_URL", False),
            ("JIRA_USERNAME", False),
            ("JIRA_API_TOKEN", True),
        )
        env_keys = []
        for config_key, secret in settings:
            env_keys.append(
                ToolConfiguration(key=config_key, key_type=ToolConfigKeyType.STRING, is_required=True, is_secret=secret)
            )
        return env_keys
================================================
FILE: superagi/tools/jira/search_issues.py
================================================
import json
from typing import Type, Dict, List
from pydantic import Field, BaseModel
from superagi.helper.token_counter import TokenCounter
from superagi.tools.jira.tool import JiraTool
class SearchIssueSchema(BaseModel):
    # JQL query; the description is shown to the LLM (typo "to the me" fixed).
    query: str = Field(
        ...,
        description="JQL query string to search issues. For example, to find all the issues in project \"Test\" assigned to me, you would pass in the following string: project = Test AND assignee = currentUser() or to find issues with summaries that contain the word \"test\", you would pass in the following string: summary ~ 'test'.",
    )
class SearchJiraTool(JiraTool):
    """
    Search Jira Issues tool

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    name = "SearchJiraIssues"
    description = "This tool is a wrapper around atlassian-python-api's Jira jql API, useful when you need to search for Jira issues."
    args_schema: Type[SearchIssueSchema] = SearchIssueSchema

    def _execute(self, query: str) -> str:
        """
        Execute the search issues tool.

        Args:
            query : JQL query string to search issues. For example, to find all the issues in project "Test"
                assigned to you, you would pass in the following string: project = Test AND assignee = currentUser()
                or to find issues with summaries that contain the word "test", you would pass in the following
                string: summary ~ 'test'.

        Returns:
            The list of issues matching the query.
        """
        jira = self.build_jira_instance()
        issues = jira.search_issues(query)
        parsed_issues = self.parse_issues(issues)
        parsed_issues_str = (
            "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues)
        )
        return parsed_issues_str

    def parse_issues(self, issues: List) -> List[dict]:
        """
        Parse the issues returned by the Jira API.

        Parsing stops right after the first issue that pushes the serialized result
        over `max_token_limit`, so the result can slightly exceed that limit.

        Args:
            issues : List of issues returned by the Jira API.

        Returns:
            List of parsed issues.
        """
        parsed = []
        for issue in issues:
            fields = issue.fields
            # Priority can be unset on an issue; don't let that abort the whole search.
            priority = fields.priority.name if fields.priority is not None else "None"
            try:
                assignee = fields.assignee.displayName
            except Exception:
                # Unassigned issues have no assignee object.
                assignee = "None"

            # Each link describes its relation from one side; build the dict from the
            # side that is present, so values never leak from a previous link or stay
            # unbound. NOTE: as before, only the last link is kept.
            rel_issues = {}
            for related_issue in fields.issuelinks:
                link_keys = related_issue.keys()
                if "inwardIssue" in link_keys:
                    rel_issues = {
                        "type": related_issue.type.inward,
                        "key": related_issue.inwardIssue.key,
                        "summary": related_issue.inwardIssue.fields.summary,
                    }
                if "outwardIssue" in link_keys:
                    rel_issues = {
                        "type": related_issue.type.outward,
                        "key": related_issue.outwardIssue.key,
                        "summary": related_issue.outwardIssue.fields.summary,
                    }

            parsed.append(
                {
                    "key": issue.key,
                    "summary": fields.summary,
                    "created": fields.created[0:10],
                    "assignee": assignee,
                    "priority": priority,
                    "status": fields.status.name,
                    "related_issues": rel_issues,
                }
            )
            if TokenCounter.count_text_tokens(json.dumps(parsed)) > self.max_token_limit:
                break
        return parsed
================================================
FILE: superagi/tools/jira/tool.py
================================================
import os
import requests
from typing import List, Type
from pydantic import BaseModel, Field
from superagi.config.config import get_config
from superagi.tools.base_tool import BaseTool
from jira import JIRA
class JiraIssueSchema(BaseModel):
    # Key of the target issue.
    issue_key: str = Field(..., description="The key of the Jira issue.")
    # Mapping of Jira field names to their new values.
    fields: dict = Field(..., description="The fields to update for the Jira issue.")
class JiraTool(BaseTool):
    """
    Jira tool

    Base class for the Jira tools; provides an authenticated client built from
    the toolkit configuration.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    def build_jira_instance(self) -> JIRA:
        """
        Build a Jira client.

        Reads JIRA_INSTANCE_URL, JIRA_USERNAME and JIRA_API_TOKEN from the tool
        configuration and authenticates with basic auth (username + API token).

        Returns:
            The `jira.JIRA` client instance.
        """
        jira_instance_url = self.get_tool_config("JIRA_INSTANCE_URL")
        jira_username = self.get_tool_config("JIRA_USERNAME")
        jira_api_token = self.get_tool_config("JIRA_API_TOKEN")
        jira = JIRA(
            server=jira_instance_url,
            basic_auth=(jira_username, jira_api_token)
        )
        return jira
================================================
FILE: superagi/tools/knowledge_search/knowledge_search.py
================================================
from superagi.models.agent_config import AgentConfiguration
from superagi.models.knowledges import Knowledges
from superagi.models.vector_db_indices import VectordbIndices
from superagi.models.vector_dbs import Vectordbs
from superagi.models.vector_db_configs import VectordbConfigs
from superagi.models.toolkit import Toolkit
from superagi.vector_store.vector_factory import VectorFactory
from superagi.models.configuration import Configuration
from superagi.jobs.agent_executor import AgentExecutor
from typing import Any, Type, List
from pydantic import BaseModel, Field
from superagi.tools.base_tool import BaseTool
# from superagi.tools.file.read_file import ReadFileTool
class KnowledgeSearchSchema(BaseModel):
    # Search text matched against the agent's selected knowledge.
    query: str = Field(
        ...,
        description="The query to search required from knowledge search",
    )
class KnowledgeSearchTool(BaseTool):
    """
    Searches the knowledge selected in the agent's configuration for text matching a query.

    Attributes:
        name : The name.
        args_schema : The args schema.
        agent_id : The agent whose "knowledge" configuration is used.
        description : The description shown to the LLM.
    """
    name: str = "Knowledge Search"
    args_schema: Type[BaseModel] = KnowledgeSearchSchema
    agent_id: int = None
    # Adjacent literals are concatenated, so each sentence ends with a space.
    description = (
        "A tool for performing a Knowledge search on knowledge base which might have knowledge of the task you are pursuing. "
        "To find relevant info, use this tool first before using other tools. "
        "If you don't find sufficient info using Knowledge tool, you may use other tools. "
        "If a question is being asked, responding with context from info returned by knowledge tool is prefered. "
        "Input should be a search query."
    )

    def _execute(self, query: str):
        """
        Run the knowledge search.

        Args:
            query : The search query.

        Returns:
            "Result: ..." with the matching text, or an error message.
        """
        session = self.toolkit_config.session
        knowledge_config = session.query(AgentConfiguration).filter(
            AgentConfiguration.agent_id == self.agent_id,
            AgentConfiguration.key == "knowledge",
        ).first()
        # The agent may have no knowledge configured at all.
        if knowledge_config is None:
            return "Selected Knowledge not found"
        knowledge = Knowledges.get_knowledge_from_id(session, knowledge_config.value)
        if knowledge is None:
            return "Selected Knowledge not found"

        vector_db_index = VectordbIndices.get_vector_index_from_id(session, knowledge.vector_db_index_id)
        if vector_db_index is None:
            return "Vector db index for the selected knowledge not found"
        vector_db = Vectordbs.get_vector_db_from_id(session, vector_db_index.vector_db_id)
        if vector_db is None:
            return "Vector db for the selected knowledge not found"
        db_creds = VectordbConfigs.get_vector_db_config_from_db_id(session, vector_db.id)

        model_api_key = self.get_tool_config('OPENAI_API_KEY')
        model_source = 'OpenAI'
        embedding_model = AgentExecutor.get_embedding(model_source, model_api_key)
        try:
            # Marketplace indices are filtered to this knowledge's name; any other
            # state (e.g. "Custom") searches the index without a metadata filter.
            filters = None
            if vector_db_index.state == "Marketplace":
                filters = {"knowledge_name": knowledge.name}
            vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, vector_db_index.name, embedding_model, **db_creds)
            search_result = vector_db_storage.get_matching_text(query, metadata=filters)
            return f"Result: \n{search_result['search_res']}"
        except Exception as err:
            return f"Error fetching text: {err}"
================================================
FILE: superagi/tools/knowledge_search/knowledge_search_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.knowledge_search.knowledge_search import KnowledgeSearchTool
from superagi.types.key_type import ToolConfigKeyType
class KnowledgeSearchToolkit(BaseToolkit, ABC):
    """Toolkit exposing the Knowledge Search tool and its optional OpenAI key."""
    name: str = "Knowledge Search Toolkit"
    description: str = "Toolkit containing tools for performing search on the knowledge base."

    def get_tools(self) -> List[BaseTool]:
        tools: List[BaseTool] = [KnowledgeSearchTool()]
        return tools

    def get_env_keys(self) -> List[ToolConfiguration]:
        # Optional key, stored as a secret.
        openai_key = ToolConfiguration(
            key="OPENAI_API_KEY",
            key_type=ToolConfigKeyType.STRING,
            is_required=False,
            is_secret=True,
        )
        return [openai_key]
================================================
FILE: superagi/tools/resource/__init__.py
================================================
================================================
FILE: superagi/tools/resource/query_resource.py
================================================
import logging
import os
from typing import Optional
from typing import Type
import openai
from langchain.chat_models import ChatOpenAI
from llama_index import VectorStoreIndex, LLMPredictor, ServiceContext
from llama_index.vector_stores.types import ExactMatchFilter, MetadataFilters
from pydantic import BaseModel, Field
from superagi.config.config import get_config
from superagi.llms.base_llm import BaseLlm
from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
from superagi.tools.base_tool import BaseTool
from superagi.types.vector_store_types import VectorStoreType
from superagi.vector_store.chromadb import ChromaDB
class QueryResource(BaseModel):
    """Input for QueryResource tool."""

    query: str = Field(
        ...,
        description="the search query to search resources",
    )
class QueryResourceTool(BaseTool):
    """
    Query Resource tool: answers a query from the resources stored for an agent.

    Attributes:
        name : The name.
        description : The description; ``{summary}`` is a placeholder
            (presumably filled with the resource summary elsewhere).
        args_schema : The args schema.
        agent_id : Agent whose resources are searched (used as a metadata filter).
        llm : LLM whose API key and model drive the llama_index query engine.
    """
    name: str = "QueryResource"
    args_schema: Type[BaseModel] = QueryResource
    # Adjacent literals are concatenated, so each sentence ends with a space.
    description: str = "Tool searches resources content and extracts relevant information to perform the given task. " \
                       "Tool is given preference over other search/read file tools for relevant data. " \
                       "Resources content is taken from the files: {summary}"
    agent_id: int = None
    llm: Optional[BaseLlm] = None

    def _execute(self, query: str):
        """
        Query the resource vector store, filtered to this agent's documents.

        Args:
            query: The question to answer from the resources.

        Returns:
            The llama_index query response, or "Document not found" when the
            query engine raises ValueError.
        """
        api_key = self.llm.get_api_key()
        openai.api_key = api_key
        os.environ["OPENAI_API_KEY"] = api_key
        # Authenticate the chat model with the agent LLM's key, the same key set
        # above, rather than the global config key, which may be unset.
        llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name=self.llm.get_model(),
                                                            openai_api_key=api_key))
        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt)
        vector_store_name = VectorStoreType.get_vector_store_type(
            self.get_tool_config(key="RESOURCE_VECTOR_STORE") or "Redis")
        vector_store_index_name = self.get_tool_config(key="RESOURCE_VECTOR_STORE_INDEX_NAME") or "super-agent-index"
        logging.info(f"vector_store_name {vector_store_name}")
        logging.info(f"vector_store_index_name {vector_store_index_name}")
        vector_store = LlamaVectorStoreFactory(vector_store_name, vector_store_index_name).get_vector_store()
        logging.info(f"vector_store {vector_store}")
        # Only documents tagged with this agent's id are searched.
        as_query_engine_args = dict(
            filters=MetadataFilters(
                filters=[
                    ExactMatchFilter(
                        key="agent_id",
                        value=str(self.agent_id)
                    )
                ]
            )
        )
        if vector_store_name == VectorStoreType.CHROMA:
            as_query_engine_args["chroma_collection"] = ChromaDB.create_collection(
                collection_name=vector_store_index_name)
        index = VectorStoreIndex.from_vector_store(vector_store=vector_store, service_context=service_context)
        query_engine = index.as_query_engine(
            **as_query_engine_args
        )
        try:
            response = query_engine.query(query)
        except ValueError as e:
            logging.error(f"ValueError {e}")
            response = "Document not found"
        return response
================================================
FILE: superagi/tools/resource/resource_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.resource.query_resource import QueryResourceTool
from superagi.types.key_type import ToolConfigKeyType
class JiraToolkit(BaseToolkit, ABC):
    """
    Resource Toolkit: exposes QueryResourceTool and its vector-store settings.

    NOTE(review): the class name says "Jira" but this is the Resource toolkit.
    The name is kept because toolkit discovery/registration may depend on it.
    """
    name: str = "Resource Toolkit"
    description: str = "Toolkit containing tools for Resource integration"

    def get_tools(self) -> List[BaseTool]:
        query_tool = QueryResourceTool()
        return [query_tool]

    def get_env_keys(self) -> List[ToolConfiguration]:
        store_key = ToolConfiguration(key="RESOURCE_VECTOR_STORE", key_type=ToolConfigKeyType.STRING,
                                      is_required=True, is_secret=True)
        index_key = ToolConfiguration(key="RESOURCE_VECTOR_STORE_INDEX_NAME", key_type=ToolConfigKeyType.STRING,
                                      is_required=True, is_secret=True)
        return [store_key, index_key]
================================================
FILE: superagi/tools/searx/README.MD
================================================
# SuperAGI Searx Search Tool
The SuperAGI Searx Search Tool helps users perform a Searx search and extract snippets and webpages. We parse the HTML response pages because most Searx instances do not support the JSON response format without an API key.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions given here (https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
## Running the SuperAGI Searx Search Tool
You can simply ask your agent for the latest information about anything in the world, and it will browse the internet to get that information for you.
================================================
FILE: superagi/tools/searx/__init__.py
================================================
================================================
FILE: superagi/tools/searx/search_scraper.py
================================================
import random
from typing import List
import httpx
from bs4 import BeautifulSoup
from pydantic import BaseModel
from superagi.lib.logger import logger
# Public Searx instances; search() picks one at random for each query.
searx_hosts = [
    "https://search.ononoki.org",
    "https://searx.be",
    "https://search.us.projectsegfau.lt",
]
class SearchResult(BaseModel):
    """
    A single hit parsed from a Searx results page.

    Attributes:
        id : Position of the result in the scraped list.
        title : The result title.
        link : The result URL.
        description : The snippet text shown under the title.
        sources : Names of the engines that returned this result.
    """
    id: int
    title: str
    link: str
    description: str
    sources: List[str]

    def __str__(self):
        # Rendered as "<id>. <title> - <link>" followed by the snippet on its own line.
        return f"""{self.id}. {self.title} - {self.link}
{self.description}"""
def search(query):
    """
    Fetch the raw HTML of a Searx results page for ``query``.

    A host is picked at random from ``searx_hosts`` and queried with a
    browser-like User-Agent header.

    Args:
        query : The query to search for.

    Returns:
        The response body as text.

    Raises:
        Exception: If the instance answers with a non-200 status code.
    """
    # TODO: use a better strategy for choosing hosts. Could use this list: https://searx.space/data/instances.json
    searx_url = random.choice(searx_hosts)
    res = httpx.get(
        searx_url + "/search", params={"q": query}, headers={"User-Agent": "Mozilla/5.0 (X11; Linux i686; rv:109.0) Gecko/20100101 Firefox/114.0"}
    )
    if res.status_code != 200:
        # Give logging a real format string. Using the status code as the message
        # with an extra positional arg makes %-formatting fail inside logging.
        logger.info("Searx returned status %s from %s", res.status_code, searx_url)
        raise Exception(f"Searx returned {res.status_code} status code")
    return res.text
def clean_whitespace(s: str):
    """
    Collapse every run of whitespace in ``s`` to a single space and trim both ends.

    Args:
        s : The string to normalise.

    Returns:
        The normalised string.
    """
    words = s.split()
    return " ".join(words)
def scrape_results(html):
    """
    Convert a Searx results page into a list of SearchResult objects.

    Result blocks without a title header or a link are skipped. A block with
    no snippet paragraph gets an empty description, and one with no engines
    container gets an empty ``sources`` list. A single malformed block
    therefore no longer aborts the whole scrape.

    Args:
        html : The raw HTML to convert.

    Returns:
        A list of SearchResult objects, numbered from 1 in page order.
    """
    soup = BeautifulSoup(html, "html.parser")
    result_list = []
    for result_div in soup.find_all(attrs={"class": "result"}):
        # Needed to work on multiple versions of Searx (h3 vs h4 titles).
        header = result_div.find(["h4", "h3"])
        if header is None:
            continue
        anchor = header.find("a")
        link = anchor.get("href") if anchor is not None else None
        if not link:
            continue
        title = header.text.strip()
        paragraph = result_div.find("p")
        description = clean_whitespace(paragraph.text) if paragraph is not None else ""
        # Needed to work on multiple versions of Searx ("pull-right" vs "engines").
        sources_container = result_div.find(
            attrs={"class": "pull-right"}
        ) or result_div.find(attrs={"class": "engines"})
        sources = []
        if sources_container is not None:
            sources = [span.text.strip() for span in sources_container.find_all("span")]
        result_list.append(
            SearchResult(
                id=len(result_list) + 1,
                title=title,
                link=link,
                description=description,
                sources=sources,
            )
        )
    return result_list
def search_results(query):
    '''Returns a text summary of the search results via the SearchResult.__str__ method'''
    page_html = search(query)
    parsed = scrape_results(page_html)
    # Blank line between consecutive results.
    return "\n\n".join(str(item) for item in parsed)
================================================
FILE: superagi/tools/searx/searx.py
================================================
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.helper.error_handler import ErrorHandler
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.tools.base_tool import BaseTool
from superagi.tools.searx.search_scraper import search_results
class SearxSearchSchema(BaseModel):
    """Arguments accepted by SearxSearchTool."""

    query: str = Field(..., description="The search query for the Searx search engine.")
class SearxSearchTool(BaseTool):
    """
    Searx Search tool: searches a public Searx instance and has the LLM
    summarise the scraped result snippets.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
        llm : LLM used to summarise the snippets.
        agent_id, agent_execution_id : Passed to ErrorHandler when the LLM call fails.
    """
    llm: Optional[BaseLlm] = None
    name = "SearxSearch"
    agent_id: int = None
    agent_execution_id: int = None
    # Adjacent literals are concatenated, so the first sentence ends with a space.
    description = (
        "A tool for performing a Searx search and extracting snippets and webpages. "
        "Input should be a search query."
    )
    args_schema: Type[SearxSearchSchema] = SearxSearchSchema

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, query: str) -> str:
        """
        Execute the Searx search tool.

        Args:
            query : The query to search for.

        Returns:
            An LLM-written summary of the Searx snippets (or an error string).
        """
        snippets = search_results(query)
        summary = self.summarise_result(query, snippets)
        return summary

    def summarise_result(self, query, snippets):
        """
        Summarise the result of the Searx search.

        Args:
            query : The query to search for.
            snippets : The snippets from the Searx search.

        Returns:
            The LLM's summary. If the completion carries no 'content' key, an
            "Error summarising search results: ..." string is returned instead
            of raising KeyError.
        """
        summarize_prompt = """Summarize the following text `{snippets}`
Write a concise or as descriptive as necessary and attempt to
answer the query: `{query}` as best as possible. Use markdown formatting for
longer responses."""
        summarize_prompt = summarize_prompt.replace("{snippets}", str(snippets))
        summarize_prompt = summarize_prompt.replace("{query}", query)
        messages = [{"role": "system", "content": summarize_prompt}]
        result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
        if 'error' in result and result.get('message') is not None:
            ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
        if 'content' not in result:
            # A failed completion may carry only 'error'/'message'; report it rather than KeyError.
            return f"Error summarising search results: {result.get('message') or result.get('error')}"
        return result["content"]
================================================
FILE: superagi/tools/searx/searx_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseToolkit, BaseTool, ToolConfiguration
from superagi.tools.searx.searx import SearxSearchTool
from superagi.types.key_type import ToolConfigKeyType
class SearxSearchToolkit(BaseToolkit, ABC):
    """Toolkit wrapping SearxSearchTool; it needs no configuration keys."""
    name: str = "Searx Toolkit"
    # The tool queries Searx instances, not Google directly.
    description: str = "Toolkit containing tools for performing web searches and extracting snippets and webpages " \
                       "using Searx"

    def get_tools(self) -> List[BaseTool]:
        return [SearxSearchTool()]

    def get_env_keys(self) -> List[ToolConfiguration]:
        return []
================================================
FILE: superagi/tools/slack/README.md
================================================
# SuperAGI Slack Toolkit
This SuperAGI Tool lets users send messages to Slack Channels and provides a strong foundation for use cases to come.
**Features:**
1. Send Message - This tool gives SuperAGI the ability to send messages to Slack Channels that you have specified
## 🛠️ Installation
### Setting up of SuperAGI:
Set up SuperAGI by following the instructions given here (https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
### 🔧 **Slack Configuration:**
1. Create an application on the Slack API portal

2. Select "from scratch"

3. Add your application's name and the workspace for which you'd like to use your Slack Application

4. Once the app creation process is done, head to the "OAuth and Permissions" tab

5. Find the **Bot Token Scopes** section, add the following scope:
**chat:write**, and save it

6. Once you've defined the scope, install the application to your workspace

7. After installation, you will get the bot token

8. Once the installation is done, you'll get the Bot User OAuth Token, which needs to be added to the Slack Toolkit Page

Once the configuration is complete, you can install the app in the channel of your choice and create an agent on SuperAGI which can now send messages to the Slack Channel.
================================================
FILE: superagi/tools/slack/__init__.py
================================================
================================================
FILE: superagi/tools/slack/send_message.py
================================================
from typing import Type

from pydantic import BaseModel, Field
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError

from superagi.config.config import get_config
from superagi.tools.base_tool import BaseTool
class SlackMessageSchema(BaseModel):
    """Arguments accepted by SlackMessageTool."""

    channel: str = Field(
        ...,
        description="Slack Channel/Group Name"
    )
    # Typo fix: "a group or people" -> "a group of people".
    message: str = Field(
        ...,
        description="Text Message to be sent to a person or a group of people"
    )
class SlackMessageTool(BaseTool):
    """
    Slack Message Tool: posts a text message through Slack's chat.postMessage.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.

    This Tool works for both Individual and Group messages. ``channel`` is
    passed to Slack unchanged:
    - Individual Texting - Provide user-id
    - Group Texting - Provide group-id (or channel name)
    """
    name = "SendSlackMessage"
    description = "Send text message in Slack"
    args_schema: Type[SlackMessageSchema] = SlackMessageSchema

    def _execute(self, channel: str, message: str):
        """
        Execute the Slack Message Tool.

        Args:
            channel : The channel name or id.
            message : The message to be sent.

        Returns:
            A success message when Slack accepts the post. Otherwise a message
            starting with 'Message sending failed!' (including Slack's error
            code when the API rejected the call).
        """
        slack = self.build_slack_web_client()
        try:
            response = slack.chat_postMessage(channel=channel, text=message)
        except SlackApiError as err:
            # WebClient raises instead of returning a response whose 'ok' is False.
            return f"Message sending failed! Slack error: {err.response.get('error', 'unknown_error')}"
        if response['ok']:
            return f'Message sent to {channel} Successfully'
        return 'Message sending failed!'

    def build_slack_web_client(self):
        """Create a WebClient authenticated with the configured SLACK_BOT_TOKEN."""
        slack_bot_token = self.get_tool_config("SLACK_BOT_TOKEN")
        return WebClient(token=slack_bot_token)
================================================
FILE: superagi/tools/slack/slack_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.slack.send_message import SlackMessageTool
from superagi.types.key_type import ToolConfigKeyType
class SlackToolkit(BaseToolkit, ABC):
    """Toolkit wrapping SlackMessageTool; requires a secret bot token."""
    name: str = "Slack Toolkit"
    description: str = "Toolkit containing tools for Slack integration"

    def get_tools(self) -> List[BaseTool]:
        message_tool = SlackMessageTool()
        return [message_tool]

    def get_env_keys(self) -> List[ToolConfiguration]:
        bot_token = ToolConfiguration(
            key="SLACK_BOT_TOKEN",
            key_type=ToolConfigKeyType.STRING,
            is_required=True,
            is_secret=True,
        )
        return [bot_token]
================================================
FILE: superagi/tools/thinking/__init__.py
================================================
================================================
FILE: superagi/tools/thinking/prompts/thinking.txt
================================================
Given the following overall objective
Objective:
{goals}
and the following task, `{task_description}`.
Below is last tool response:
`{last_tool_response}`
Below is the relevant tool response:
`{relevant_tool_response}`
Perform the task by understanding the problem, extracting variables, and being smart
and efficient. Provide a descriptive response, make decisions yourself when
confronted with choices and provide reasoning for ideas / decisions.
================================================
FILE: superagi/tools/thinking/thinking_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.thinking.tools import ThinkingTool
from superagi.types.key_type import ToolConfigKeyType
class ThinkingToolkit(BaseToolkit, ABC):
    """Toolkit wrapping ThinkingTool; it needs no configuration keys."""
    name: str = "Thinking Toolkit"
    description: str = "Toolkit containing tools for intelligent problem-solving"

    def get_tools(self) -> List[BaseTool]:
        thinker = ThinkingTool()
        return [thinker]

    def get_env_keys(self) -> List[ToolConfiguration]:
        no_keys: List[ToolConfiguration] = []
        return no_keys
================================================
FILE: superagi/tools/thinking/tools.py
================================================
from typing import Type, Optional, List
from pydantic import BaseModel, Field
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.helper.error_handler import ErrorHandler
from superagi.helper.prompt_reader import PromptReader
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
class ThinkingSchema(BaseModel):
    """Arguments accepted by ThinkingTool."""

    task_description: str = Field(..., description="Task description which needs reasoning.")
class ThinkingTool(BaseTool):
    """
    Thinking tool: fills the thinking.txt prompt with the goals, task, and
    previous tool output, then asks the LLM to reason about the task.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
        llm: LLM used for thinking.
        goals: Agent goals inserted into the prompt.
        tool_response_manager: Optional source of previous tool responses.
    """
    llm: Optional[BaseLlm] = None
    name = "ThinkingTool"
    description = (
        "Intelligent problem-solving assistant that comprehends tasks, identifies key variables, and makes efficient decisions, all while providing detailed, self-driven reasoning for its choices. Do not assume anything, take the details from given data only."
    )
    args_schema: Type[ThinkingSchema] = ThinkingSchema
    goals: List[str] = []
    agent_execution_id: int = None
    agent_id: int = None
    permission_required: bool = False
    tool_response_manager: Optional[ToolResponseQueryManager] = None

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, task_description: str):
        """
        Execute the Thinking tool.

        Args:
            task_description : The task description.

        Returns:
            The LLM's reasoning for the task, or "Error generating text: ..." on failure.
        """
        try:
            prompt = PromptReader.read_tools_prompt(__file__, "thinking.txt")
            prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
            prompt = prompt.replace("{task_description}", task_description)
            last_tool_response = ""
            relevant_tool_response = ""
            # The manager is Optional; without it the prompt simply gets empty context.
            if self.tool_response_manager is not None:
                # str.replace needs a str, so a None lookup result becomes "".
                last_tool_response = self.tool_response_manager.get_last_response() or ""
                metadata = {"agent_execution_id": self.agent_execution_id}
                relevant_tool_response = self.tool_response_manager.get_relevant_response(
                    query=task_description, metadata=metadata) or ""
            prompt = prompt.replace("{last_tool_response}", last_tool_response)
            prompt = prompt.replace("{relevant_tool_response}", relevant_tool_response)
            messages = [{"role": "system", "content": prompt}]
            result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
            if 'error' in result and result.get('message') is not None:
                ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message'])
            if 'content' not in result:
                # A failed completion may carry only 'error'/'message'.
                return f"Error generating text: {result.get('message') or result.get('error')}"
            return result["content"]
        except Exception as e:
            logger.error(e)
            return f"Error generating text: {e}"
================================================
FILE: superagi/tools/tool_response_query_manager.py
================================================
from sqlalchemy.orm import Session
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.vector_store.base import VectorStore
class ToolResponseQueryManager:
    """
    Looks up earlier tool responses for one agent execution.

    The latest response comes from AgentExecutionFeed. Semantically similar
    responses come from the optional vector-store ``memory``.
    """

    def __init__(self, session: Session, agent_execution_id: int, memory: VectorStore):
        self.session = session
        self.agent_execution_id = agent_execution_id
        self.memory = memory

    def get_last_response(self, tool_name: str = None):
        """Delegate to AgentExecutionFeed.get_last_tool_response for this execution, passing ``tool_name`` through."""
        return AgentExecutionFeed.get_last_tool_response(self.session, self.agent_execution_id, tool_name)

    def get_relevant_response(self, query: str, metadata: dict, top_k: int = 5):
        """
        Concatenate the text of the stored documents that best match ``query``.

        Args:
            query: Text to match against stored responses.
            metadata: Metadata filter forwarded to the vector store.
            top_k: Maximum number of matches requested from the vector store.

        Returns:
            The matched documents' ``text_content`` joined with no separator,
            or "" when no memory is configured.
        """
        if self.memory is None:
            return ""
        # top_k was previously accepted but never forwarded to the store.
        documents = self.memory.get_matching_text(query, top_k=top_k, metadata=metadata)
        return "".join(document.text_content for document in documents["documents"])
================================================
FILE: superagi/tools/twitter/README.md
================================================
# SuperAGI Twitter Toolkit
Introducing Twitter Toolkit for SuperAGI. With Twitter Integrated into SuperAGI, you can now deploy agents to
1. Send Tweets
2. Send Tweets with Images
## Installation
### 🛠️ Setting up SuperAGI:
Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
### 🔐 Obtaining API Key and Secret from Twitter Developer Portal
1. Log in to your Twitter Developer Portal Account and select your project under the “Projects & Apps” section.

2. Proceed with creating a new app. Once you have created the app by adding a name, you will get an API Key and an API Secret; copy them and keep them in a separate text file.


### 🚪 Configuring OAuth
3. Once you have saved the key and the secret, click on “App Settings”
4. Once you are on the App Settings Page, start setting up the User Authentication Settings.

5. Fill in the details as shown in the image below. Grant “Read and Write” permissions and set the app type to “Web Application”.

6. Add the Callback URI and the Website URL as shown in the image below

7. Save the settings. You have now configured OAuth authentication for Twitter.
### ✅ Configuring Keys and Authenticating in SuperAGI.
1. In the SuperAGI’s Dashboard, navigate to the Twitter Toolkit Page, add the API Key and API Secret you’ve saved, and click on ‘Update Changes’

2. After you’ve updated the changes, click on Authenticate. This will take you to the OAuth Flow. Authorize the app through the flow.

Once you have followed the above steps, you have successfully integrated Twitter with SuperAGI.
================================================
FILE: superagi/tools/twitter/send_tweets.py
================================================
from typing import Type
from pydantic import BaseModel, Field
from superagi.helper.twitter_helper import TwitterHelper
from superagi.helper.twitter_tokens import TwitterTokens
from superagi.tools.base_tool import BaseTool
class SendTweetsInput(BaseModel):
    """Arguments accepted by SendTweetsTool."""

    tweet_text: str = Field(
        ...,
        description="Tweet text to be posted from twitter handle, if no value is given keep the default value as 'None'",
    )
    is_media: bool = Field(
        ...,
        description="'True' if there is any media to be posted with Tweet else 'False'.",
    )
    media_files: list = Field(
        ...,
        description="Name of the media files to be uploaded.",
    )
class SendTweetsTool(BaseTool):
    """
    Posts a tweet, optionally with media, using the toolkit's stored Twitter credentials.

    Attributes:
        name : The name.
        args_schema : The args schema.
        description : The description.
        agent_id, agent_execution_id : Passed to TwitterHelper when uploading media.
    """
    name: str = "Send Tweets Tool"
    args_schema: Type[BaseModel] = SendTweetsInput
    description: str = "Send and Schedule Tweets for your Twitter Handle"
    agent_id: int = None
    agent_execution_id: int = None

    def _execute(self, is_media: bool, tweet_text: str = 'None', media_files: list = None):
        """
        Post a tweet.

        Args:
            is_media: Whether ``media_files`` should be uploaded and attached.
            tweet_text: Tweet body. The input schema tells the LLM to send the
                string 'None' when there is no text, so that sentinel (like a
                real None) means "no text".
            media_files: Names of media files to upload when ``is_media`` is true.

        Returns:
            A success message when Twitter answers 201, otherwise an error
            message containing the status code.
        """
        toolkit_id = self.toolkit_config.toolkit_id
        session = self.toolkit_config.session
        creds = TwitterTokens(session).get_twitter_creds(toolkit_id)
        params = {}
        if is_media:
            media_ids = TwitterHelper().get_media_ids(session, media_files or [], creds, self.agent_id,
                                                      self.agent_execution_id)
            params["media"] = {"media_ids": media_ids}
        # Do not post the literal placeholder 'None' as tweet text.
        if tweet_text is not None and tweet_text != 'None':
            params["text"] = tweet_text
        tweet_response = TwitterHelper().send_tweets(params, creds)
        if tweet_response.status_code == 201:
            return "Tweet posted successfully!!"
        return "Error posting tweet. (Status code: {})".format(tweet_response.status_code)
================================================
FILE: superagi/tools/twitter/twitter_toolkit.py
================================================
from abc import ABC
from superagi.tools.base_tool import BaseToolkit, BaseTool, ToolConfiguration
from typing import Type, List
from superagi.tools.twitter.send_tweets import SendTweetsTool
from superagi.types.key_type import ToolConfigKeyType
class TwitterToolkit(BaseToolkit, ABC):
    """Toolkit wrapping SendTweetsTool; requires the Twitter API key and secret."""
    name: str = "Twitter Toolkit"
    description: str = "Twitter Tool kit contains all tools related to Twitter"

    def get_tools(self) -> List[BaseTool]:
        tweet_tool = SendTweetsTool()
        return [tweet_tool]

    def get_env_keys(self) -> List[ToolConfiguration]:
        api_key = ToolConfiguration(key="TWITTER_API_KEY", key_type=ToolConfigKeyType.STRING,
                                    is_required=True, is_secret=True)
        api_secret = ToolConfiguration(key="TWITTER_API_SECRET", key_type=ToolConfigKeyType.STRING,
                                       is_required=True, is_secret=True)
        return [api_key, api_secret]
================================================
FILE: superagi/tools/webscaper/README.MD
================================================
# SuperAGI Web Scraper Tool
The SuperAGI Webscraper Tool lets users perform web scraping, extracting URLs and retrieving the textual content from websites.
## ⚙️ Installation
### 🛠 **Setting Up of SuperAGI**
Set up SuperAGI by following the instructions given here (https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
You'll be able to use the Web Scraper Tool right away once you have set up SuperAGI.
## Running SuperAGI Web Scraper Tool
You can simply ask your agent to read or go through a certain website or URL, and it will retrieve its textual content.
================================================
FILE: superagi/tools/webscaper/__init__.py
================================================
================================================
FILE: superagi/tools/webscaper/tools.py
================================================
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.helper.webpage_extractor import WebpageExtractor
from superagi.llms.base_llm import BaseLlm
from superagi.tools.base_tool import BaseTool
class WebScraperSchema(BaseModel):
    """Arguments accepted by WebScraperTool."""

    website_url: str = Field(..., description="Valid website url without any quotes.")
class WebScraperTool(BaseTool):
    """
    Web Scraper tool: extracts a page's text with WebpageExtractor and
    returns at most its first 600 space-separated words.

    Attributes:
        name : The name.
        description : The description.
        args_schema : The args schema.
    """
    llm: Optional[BaseLlm] = None
    name = "WebScraperTool"
    description = (
        "Used to scrape website urls and extract text content"
    )
    args_schema: Type[WebScraperSchema] = WebScraperSchema

    class Config:
        arbitrary_types_allowed = True

    def _execute(self, website_url: str) -> str:
        """
        Execute the Web Scraper tool.

        Args:
            website_url : The website url to scrape.

        Returns:
            The text content of the website, truncated to 600 words.
        """
        content = WebpageExtractor().extract_with_bs4(website_url)
        # Re-joining the first N pieces of a split on " " reproduces exactly the
        # original prefix, so this is a word-bounded truncation of `content`.
        max_words = 600
        return " ".join(content.split(" ")[:max_words])
================================================
FILE: superagi/tools/webscaper/web_scraper_toolkit.py
================================================
from abc import ABC
from typing import List
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.webscaper.tools import WebScraperTool
from superagi.types.key_type import ToolConfigKeyType
class WebScrapperToolkit(BaseToolkit, ABC):
    """Toolkit wrapping WebScraperTool; it needs no configuration keys."""
    # Registered name/description strings are kept as-is (including "Scrapper").
    name: str = "Web Scrapper Toolkit"
    description: str = "Web Scrapper tool kit is used to scrape web"

    def get_tools(self) -> List[BaseTool]:
        scraper = WebScraperTool()
        return [scraper]

    def get_env_keys(self) -> List[ToolConfiguration]:
        no_keys: List[ToolConfiguration] = []
        return no_keys
================================================
FILE: superagi/types/__init__.py
================================================
================================================
FILE: superagi/types/common.py
================================================
from abc import abstractmethod
from pydantic import BaseModel, Field
class BaseMessage(BaseModel):
    """Common base for chat messages: text ``content`` plus free-form extras."""

    content: str
    # default_factory gives every instance its own empty dict.
    additional_kwargs: dict = Field(default_factory=dict)

    @property
    @abstractmethod
    def type(self) -> str:
        """Message type used."""
class HumanMessage(BaseMessage):
    """Message written by the human; its ``type`` is "user"."""

    example: bool = False

    @property
    def type(self) -> str:
        role = "user"
        return role
class AIMessage(BaseMessage):
    """Message produced by the AI; its ``type`` is "assistant"."""

    example: bool = False

    @property
    def type(self) -> str:
        role = "assistant"
        return role
class SystemMessage(BaseMessage):
    """System-level instruction message; its ``type`` is "system"."""

    @property
    def type(self) -> str:
        role = "system"
        return role
class GitHubLinkRequest(BaseModel):
    """Request body for the install API: the GitHub link to install from."""

    github_link: str
================================================
FILE: superagi/types/key_type.py
================================================
from enum import Enum
class ToolConfigKeyType(Enum):
    """Value types a tool configuration key may hold."""
    STRING = 'string'
    FILE = 'file'
    INT = 'int'

    @classmethod
    def get_key_type(cls, store):
        """
        Look up a member by case-insensitive name.

        Raises:
            ValueError: If the upper-cased name is not a member.
        """
        normalised = store.upper()
        if normalised not in cls.__members__:
            raise ValueError(f"{normalised} is not a valid key type.")
        return cls[normalised]

    def __str__(self):
        return self.value
================================================
FILE: superagi/types/model_source_types.py
================================================
from enum import Enum
class ModelSourceType(Enum):
    """Providers an LLM can come from; values are display names."""
    GooglePalm = 'Google Palm'
    OpenAI = 'OpenAi'
    Replicate = 'Replicate'
    HuggingFace = 'Hugging Face'
    LocalLLM = 'Local LLM'

    @classmethod
    def get_model_source_type(cls, name):
        """
        Resolve a source by name, ignoring case and spaces (e.g. "Google Palm" -> GooglePalm).

        Raises:
            ValueError: If no member matches.
        """
        normalised = name.upper().replace(" ", "")
        for member in cls.__members__:
            if normalised == member.upper():
                return cls[member]
        # Message previously said "vector store name" (copy-paste error).
        raise ValueError(f"{name} is not a valid model source name.")

    @classmethod
    def get_model_source_from_model(cls, model_name: str):
        """Map a known model name to its source; unknown models default to OpenAI."""
        open_ai_models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
        google_models = ['google-palm-bison-001', 'models/chat-bison-001']
        replicate_models = ['replicate-llama13b-v2-chat']
        if model_name in open_ai_models:
            return ModelSourceType.OpenAI
        if model_name in google_models:
            return ModelSourceType.GooglePalm
        if model_name in replicate_models:
            return ModelSourceType.Replicate
        return ModelSourceType.OpenAI

    def __str__(self):
        return self.value
================================================
FILE: superagi/types/queue_status.py
================================================
from enum import Enum
class QueueStatus(Enum):
    """Status values for queue entries."""
    INITIATED = 'INITIATED'
    PROCESSING = 'PROCESSING'
    COMPLETE = 'COMPLETE'

    @classmethod
    def get_queue_type(cls, store):
        """
        Look up a status by case-insensitive name.

        Raises:
            ValueError: If ``store`` is None or not a member name.
        """
        if store is None:
            raise ValueError("Queue status type cannot be None.")
        store = store.upper()
        if store in cls.__members__:
            return cls[store]
        # Message previously said "storage name" (copy-paste error).
        raise ValueError(f"{store} is not a valid queue status.")
================================================
FILE: superagi/types/storage_types.py
================================================
from enum import Enum
class StorageType(Enum):
    """Storage backends; member values equal their names."""
    FILE = 'FILE'
    S3 = 'S3'

    @classmethod
    def get_storage_type(cls, store):
        """Resolve a case-insensitive backend name (e.g. "s3") to its member.

        Raises:
            ValueError: if ``store`` is None or not a backend name.
        """
        if store is None:
            raise ValueError("Storage type cannot be None.")
        key = store.upper()
        if key not in cls.__members__:
            raise ValueError(f"{key} is not a valid storage name.")
        return cls.__members__[key]
================================================
FILE: superagi/types/vector_store_types.py
================================================
from enum import Enum
class VectorStoreType(Enum):
    """Supported vector store backends."""
    REDIS = 'redis'
    PINECONE = 'pinecone'
    CHROMA = 'chroma'
    WEAVIATE = 'weaviate'
    QDRANT = 'qdrant'
    LANCEDB = 'LanceDB'

    @classmethod
    def get_vector_store_type(cls, store):
        """Resolve a vector store to its member.

        Args:
            store: A ``VectorStoreType`` member (returned unchanged) or a
                case-insensitive member name such as "pinecone".

        Raises:
            ValueError: if ``store`` is None or names no member.
        """
        # Some callers (e.g. VectorEmbeddingFactory) pass an already-resolved
        # member; calling .upper() on it used to raise AttributeError.
        if isinstance(store, cls):
            return store
        if store is None:
            raise ValueError("Vector store type cannot be None.")
        store = store.upper()
        if store in cls.__members__:
            return cls[store]
        raise ValueError(f"{store} is not a valid vector store name.")

    def __str__(self):
        return self.value
================================================
FILE: superagi/vector_embeddings/__init__.py
================================================
================================================
FILE: superagi/vector_embeddings/base.py
================================================
import warnings
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Tuple
class VectorEmbeddings(ABC):
    """Abstract base for objects that package chunk embeddings for a vector DB."""

    @abstractmethod
    def get_vector_embeddings_from_chunks(self, final_chunks: Any):
        """Return the embeddings payload shaped for the target vector DB.

        Args:
            final_chunks: The chunks to convert. NOTE(review): the concrete
                subclasses in this package override this with no argument and
                use state captured in ``__init__`` instead.
        """
================================================
FILE: superagi/vector_embeddings/pinecone.py
================================================
from typing import Any
from superagi.vector_embeddings.base import VectorEmbeddings
class Pinecone(VectorEmbeddings):
    """Packages chunk embeddings as (id, embedding, metadata) tuples for Pinecone."""

    def __init__(self, uuid, embeds, metadata):
        """
        Args:
            uuid: Sequence of vector ids.
            embeds: Sequence of embedding vectors, zipped positionally with ``uuid``.
            metadata: Sequence of metadata dicts, zipped positionally with ``uuid``.
        """
        self.uuid = uuid
        self.embeds = embeds
        self.metadata = metadata

    def get_vector_embeddings_from_chunks(self):
        """Return ``{'vectors': [(id, embedding, metadata), ...]}``."""
        vector_tuples = [entry for entry in zip(self.uuid, self.embeds, self.metadata)]
        return {'vectors': vector_tuples}
================================================
FILE: superagi/vector_embeddings/qdrant.py
================================================
from typing import Any
from superagi.vector_embeddings.base import VectorEmbeddings
class Qdrant(VectorEmbeddings):
    """Packages chunk embeddings as parallel id / payload / vector lists for Qdrant."""

    def __init__(self, uuid, embeds, metadata):
        """
        Args:
            uuid: Sequence of point ids.
            embeds: Sequence of embedding vectors.
            metadata: Sequence of payload dicts.
        """
        self.uuid = uuid
        self.embeds = embeds
        self.metadata = metadata

    def get_vector_embeddings_from_chunks(self):
        """Return ``{'ids': ..., 'payload': ..., 'vectors': ...}`` using the stored lists as-is."""
        return {
            'ids': self.uuid,
            'payload': self.metadata,
            'vectors': self.embeds,
        }
================================================
FILE: superagi/vector_embeddings/vector_embedding_factory.py
================================================
import pinecone
from typing import Optional
from pinecone import UnauthorizedException
from superagi.vector_embeddings.pinecone import Pinecone
from superagi.vector_embeddings.qdrant import Qdrant
from superagi.vector_embeddings.weaviate import Weaviate
from superagi.types.vector_store_types import VectorStoreType
class VectorEmbeddingFactory:

    @classmethod
    def build_vector_storage(cls, vector_store: VectorStoreType, chunk_json: Optional[dict] = None):
        """
        Build the store-specific embeddings wrapper from chunk JSON.

        Args:
            vector_store : The vector store, as a VectorStoreType member or its name.
            chunk_json : Mapping of chunk key -> chunk dict. Each chunk dict must
                contain "id", "embeds", "text", "chunk" and "knowledge_name".
        Returns:
            A Pinecone, Qdrant or Weaviate embeddings object, or None for any
            other store type.
        """
        # Accept an already-resolved member; get_vector_store_type only
        # understands names and would fail on the annotated enum type.
        if not isinstance(vector_store, VectorStoreType):
            vector_store = VectorStoreType.get_vector_store_type(vector_store)

        uuid = []
        embeds = []
        metadata = []
        if chunk_json is not None:
            for chunk in chunk_json.values():
                uuid.append(chunk["id"])
                embeds.append(chunk["embeds"])
                metadata.append({
                    'text': chunk['text'],
                    'chunk': chunk['chunk'],
                    'knowledge_name': chunk['knowledge_name'],
                })

        if vector_store == VectorStoreType.PINECONE:
            return Pinecone(uuid, embeds, metadata)
        if vector_store == VectorStoreType.QDRANT:
            return Qdrant(uuid, embeds, metadata)
        if vector_store == VectorStoreType.WEAVIATE:
            return Weaviate(uuid, embeds, metadata)
        return None
================================================
FILE: superagi/vector_embeddings/weaviate.py
================================================
from typing import Any
from superagi.vector_embeddings.base import VectorEmbeddings
class Weaviate(VectorEmbeddings):
    """Packages chunk embeddings as id / data_object / vector lists for Weaviate."""

    def __init__(self, uuid, embeds, metadata):
        """
        Args:
            uuid: Sequence of object ids.
            embeds: Sequence of embedding vectors.
            metadata: Sequence of data-object dicts.
        """
        self.uuid = uuid
        self.embeds = embeds
        self.metadata = metadata

    def get_vector_embeddings_from_chunks(self):
        """Return ``{'ids': ..., 'data_object': ..., 'vectors': ...}`` using the stored lists as-is."""
        payload = dict(ids=self.uuid, data_object=self.metadata, vectors=self.embeds)
        return payload
================================================
FILE: superagi/vector_store/__init__.py
================================================
================================================
FILE: superagi/vector_store/base.py
================================================
import warnings
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Tuple
from superagi.vector_store.document import Document
class VectorStore(ABC):
    """Interface implemented by every vector store backend."""

    @abstractmethod
    def add_texts(
            self,
            texts: Iterable[str],
            metadatas: Optional[List[dict]] = None,
            **kwargs: Any,
    ) -> List[str]:
        """Add texts (with optional per-text metadata) to the vector store and return their ids."""

    @abstractmethod
    def get_matching_text(self, query: str, top_k: int, metadata: Optional[dict], **kwargs: Any) -> List[Document]:
        """Return docs most similar to query using specified search type."""

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add each Document's text and metadata through :meth:`add_texts`.

        ``documents`` is traversed exactly once, so a generator works too
        (two separate passes would leave the metadata list empty).
        """
        texts = []
        metadatas = []
        for document in documents:
            texts.append(document.text_content)
            metadatas.append(document.metadata)
        return self.add_texts(texts, metadatas, **kwargs)

    @abstractmethod
    def get_index_stats(self) -> dict:
        """Returns stats or information of an index"""

    @abstractmethod
    def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
        """Add precomputed embeddings to the vector store."""

    @abstractmethod
    def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
        """Delete embeddings with the given ids from the vector store."""
================================================
FILE: superagi/vector_store/chromadb.py
================================================
import uuid
from typing import Any, Optional, Iterable, List
import chromadb
from chromadb import Settings
from superagi.config.config import get_config
from superagi.vector_store.base import VectorStore
from superagi.vector_store.document import Document
from superagi.vector_store.embedding.base import BaseEmbedding
def _build_chroma_client():
    """Create a REST Chroma client from the CHROMA_HOST_NAME / CHROMA_PORT
    config values, defaulting to localhost:8000 when they are unset."""
    settings = Settings(
        chroma_api_impl="rest",
        chroma_server_host=get_config("CHROMA_HOST_NAME") or "localhost",
        chroma_server_http_port=get_config("CHROMA_PORT") or 8000,
    )
    return chromadb.Client(settings)
class ChromaDB(VectorStore):
    """Vector store backed by a Chroma collection reached through the REST client."""

    def __init__(
            self,
            collection_name: str,
            embedding_model: BaseEmbedding,
            text_field: str,
            namespace: Optional[str] = "",
    ):
        self.client = _build_chroma_client()
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.text_field = text_field
        self.namespace = namespace

    @classmethod
    def create_collection(cls, collection_name):
        """Create a Chroma Collection.

        Args:
            collection_name: The name of the collection to create.
        """
        chroma_client = _build_chroma_client()
        return chroma_client.get_or_create_collection(name=collection_name)

    def add_texts(
            self,
            texts: Iterable[str],
            metadatas: Optional[List[dict]] = None,
            ids: Optional[List[str]] = None,
            namespace: Optional[str] = None,
            batch_size: int = 32,
            **kwargs: Any,
    ) -> List[str]:
        """Add texts to the vector store.

        Each stored metadata dict is a copy of the caller's i-th dict (or an
        empty dict when fewer are given) with the text added under
        ``text_field``. Embeddings come from ``self.embedding_model`` so they
        match the query embeddings used in :meth:`get_matching_text`.
        ``namespace`` and ``batch_size`` are accepted for interface
        compatibility but are not used.

        Returns:
            The ids of the stored documents.
        """
        # Materialise so len() and repeated iteration work on any iterable.
        texts = list(texts)
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        if len(ids) < len(texts):
            raise ValueError("Number of ids must match number of texts.")

        # Previously the caller's metadatas were discarded (reset to []) and the
        # pop/append on the same list collapsed them to a single entry.
        provided = list(metadatas) if metadatas else []
        final_metadatas = []
        for position, text in enumerate(texts):
            metadata = dict(provided[position]) if position < len(provided) else {}
            metadata[self.text_field] = text
            final_metadatas.append(metadata)

        embeddings = [self.embedding_model.get_embedding(text) for text in texts]

        collection = self.client.get_collection(name=self.collection_name)
        collection.add(
            embeddings=embeddings,
            documents=texts,
            metadatas=final_metadatas,
            ids=ids
        )
        return ids

    def get_matching_text(self, query: str, top_k: int = 5, metadata: Optional[dict] = None, **kwargs: Any) -> List[
        Document]:
        """Return docs most similar to query, optionally filtered by exact metadata values."""
        embedding_vector = self.embedding_model.get_embedding(query)
        collection = self.client.get_collection(name=self.collection_name)
        # None (rather than {}) means "no filter".
        filters = dict(metadata) if metadata else None
        results = collection.query(
            query_embeddings=embedding_vector,
            # metadatas must be included: they are read below and would otherwise be None.
            include=["documents", "metadatas"],
            n_results=top_k,
            where=filters
        )
        documents = []
        for text, doc_metadata in zip(results["documents"][0], results["metadatas"][0]):
            documents.append(
                Document(
                    text_content=text,
                    metadata=doc_metadata or {}
                )
            )
        return documents

    def get_index_stats(self) -> dict:
        pass

    def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
        pass

    def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
        pass
================================================
FILE: superagi/vector_store/document.py
================================================
from pydantic import BaseModel, Field
class Document(BaseModel):
    """A piece of text plus arbitrary metadata, as returned by the vector stores."""

    # The document text; defaults to None.
    text_content: str = None
    # Free-form metadata; each instance gets its own fresh dict.
    metadata: dict = Field(default_factory=dict)

    def __init__(self, text_content, *args, **kwargs):
        # Allow text_content to be given positionally as well as by keyword.
        super().__init__(*args, text_content=text_content, **kwargs)
================================================
FILE: superagi/vector_store/embedding/__init__.py
================================================
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.embedding.palm import PalmEmbedding
__all__ = ['OpenAiEmbedding', 'PalmEmbedding']
================================================
FILE: superagi/vector_store/embedding/base.py
================================================
from abc import ABC, abstractmethod
class BaseEmbedding(ABC):
    """Abstract interface for embedding models."""

    @abstractmethod
    def get_embedding(self, text):
        """Return the embedding for ``text``; concrete subclasses must implement this."""
        return None
================================================
FILE: superagi/vector_store/embedding/openai.py
================================================
import openai
class OpenAiEmbedding:
    """Embeds text through the pre-1.0 ``openai.Embedding`` API.

    On failure the embedding methods return ``{"error": exception}`` instead
    of raising.
    """

    def __init__(self, api_key, model="text-embedding-ada-002"):
        self.model = model
        self.api_key = api_key

    async def get_embedding_async(self, text: str):
        """Asynchronously return the embedding vector for ``text``."""
        try:
            # Embedding.create is synchronous; awaiting its result raised a
            # TypeError that was swallowed below. acreate is the async variant.
            # The key is passed per request instead of mutating global openai.api_key.
            response = await openai.Embedding.acreate(
                api_key=self.api_key,
                input=[text],
                engine=self.model
            )
            return response['data'][0]['embedding']
        except Exception as exception:
            return {"error": exception}

    def get_embedding(self, text):
        """Return the embedding vector for ``text``."""
        try:
            response = openai.Embedding.create(
                api_key=self.api_key,
                input=[text],
                engine=self.model
            )
            return response['data'][0]['embedding']
        except Exception as exception:
            return {"error": exception}
================================================
FILE: superagi/vector_store/embedding/palm.py
================================================
import openai
import google.generativeai as palm
class PalmEmbedding:
    """Embeds text with Google PaLM; returns ``{"error": exception}`` on failure."""

    def __init__(self, api_key, model="models/embedding-gecko-001"):
        self.model = model
        self.api_key = api_key

    def get_embedding(self, text):
        """Return the embedding vector for ``text``."""
        try:
            # The stored key was previously never used, so embedding only worked
            # if something else had already configured the palm module.
            palm.configure(api_key=self.api_key)
            response = palm.generate_embeddings(model=self.model, text=text)
            return response['embedding']
        except Exception as exception:
            return {"error": exception}
================================================
FILE: superagi/vector_store/pinecone.py
================================================
import uuid
from superagi.vector_store.document import Document
from superagi.vector_store.base import VectorStore
from typing import Any, Callable, Optional, Iterable, List
from superagi.vector_store.embedding.base import BaseEmbedding
class Pinecone(VectorStore):
    """
    Pinecone vector store.

    Attributes:
        index : The pinecone index.
        embedding_model : The embedding model.
        text_field : The text field is the name of the field where the corresponding text for an embedding is stored.
        namespace : The namespace.
    """

    def __init__(
            self,
            index: Any,
            embedding_model: Optional[Any] = None,
            text_field: Optional[str] = 'text',
            namespace: Optional[str] = '',
    ):
        try:
            import pinecone
        except ImportError:
            raise ValueError("Please install pinecone to use this vector store.")

        if not isinstance(index, pinecone.index.Index):
            raise ValueError("Please provide a valid pinecone index.")

        self.index = index
        self.embedding_model = embedding_model
        self.text_field = text_field
        self.namespace = namespace

    def add_texts(
            self,
            texts: Iterable[str],
            metadatas: Optional[list[dict]] = None,
            ids: Optional[list[str]] = None,
            namespace: Optional[str] = None,
            batch_size: int = 32,
            **kwargs: Any,
    ) -> list[str]:
        """
        Add texts to the vector store.

        Args:
            texts : The texts to add.
            metadatas : The metadatas to add; the i-th dict is paired with the
                i-th text. The caller's list and dicts are not modified.
            ids : The ids to add.
            namespace : The namespace to add to (defaults to ``self.namespace``).
            batch_size : Accepted for interface compatibility; not used.
            **kwargs : The keyword arguments to add.

        Returns:
            The list of ids vectors stored in pinecone.
        """
        if namespace is None:
            namespace = self.namespace

        # Materialise so len() works on any iterable.
        texts = list(texts)
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        if len(ids) < len(texts):
            raise ValueError("Number of ids must match number of texts.")
        metadatas = metadatas or []

        vectors = []
        for position, (text, vector_id) in enumerate(zip(texts, ids)):
            # Copy instead of popping from / mutating the caller's metadata.
            metadata = dict(metadatas[position]) if position < len(metadatas) else {}
            metadata[self.text_field] = text
            vectors.append((vector_id, self.embedding_model.get_embedding(text), metadata))

        # Pass the namespace through so upserts land where queries look.
        self.add_embeddings_to_vector_db({"vectors": vectors, "namespace": namespace})
        return ids

    def get_matching_text(self, query: str, top_k: int = 5, metadata: Optional[dict] = None, **kwargs: Any) -> List[Document]:
        """
        Return docs most similar to query using specified search type.

        Args:
            query : The query to search.
            top_k : The top k to search.
            metadata : Optional exact-match metadata filters.
            **kwargs : The keyword arguments to search; "namespace" overrides ``self.namespace``.

        Returns:
            A dict with "documents" (list of Document) and "search_res" (formatted text).
        """
        namespace = kwargs.get("namespace", self.namespace)
        filters = {key: {"$eq": value} for key, value in (metadata or {}).items()}

        embed_text = self.embedding_model.get_embedding(query)
        res = self.index.query(embed_text, filter=filters, top_k=top_k, namespace=namespace, include_metadata=True)
        search_res = self._get_search_text(res, query)

        documents = self._build_documents(res)
        return {"documents": documents, "search_res": search_res}

    def get_index_stats(self) -> dict:
        """
        Returns:
            Stats or Information about an index
        """
        index_stats = self.index.describe_index_stats()
        dimensions = index_stats.dimension
        vector_count = index_stats.total_vector_count

        return {"dimensions": dimensions, "vector_count": vector_count}

    def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
        """Upserts embeddings to the given vector store.

        Args:
            embeddings: Dict with a "vectors" list of (id, values, metadata)
                tuples and an optional "namespace".
        """
        upsert_kwargs = {"vectors": embeddings['vectors']}
        namespace = embeddings.get("namespace")
        # Only send a namespace when one is set, so default-namespace upserts
        # issue exactly the same request as before.
        if namespace:
            upsert_kwargs["namespace"] = namespace
        self.index.upsert(**upsert_kwargs)

    def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
        """Deletes embeddings from the given vector store"""
        self.index.delete(ids=ids)

    def _build_documents(self, results: List[dict]):
        """Convert query matches into Documents using ``text_field`` for the text."""
        return [
            Document(
                text_content=match['metadata'][self.text_field],
                metadata=match['metadata'],
            )
            for match in results['matches']
        ]

    def _get_search_text(self, results: List[dict], query: str):
        """Format the query and each matched chunk's text as a single string."""
        search_res = f"Query: {query}\n"
        for position, item in enumerate(results['matches']):
            search_res += f"Chunk{position}: \n{item['metadata'][self.text_field]}\n"
        return search_res
================================================
FILE: superagi/vector_store/qdrant.py
================================================
from __future__ import annotations
import uuid
from mimetypes import common_types
from typing import Any, Dict, Iterable, List, Optional, Tuple, Sequence, Union
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.conversions import common_types
from qdrant_client.models import Distance, VectorParams
from superagi.vector_store.base import VectorStore
from superagi.vector_store.document import Document
from superagi.config.config import get_config
DictFilter = Dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]


def create_qdrant_client(api_key: Optional[str] = None, url: Optional[str] = None, port: Optional[int] = None
                         ) -> QdrantClient:
    """Create a Qdrant client.

    With an ``api_key`` the client connects to ``url``/``port`` directly.
    Without one it uses the QDRANT_HOST_NAME / QDRANT_PORT config values,
    defaulting to localhost:6333.
    """
    if api_key is not None:
        return QdrantClient(api_key=api_key, url=url, port=port)
    host = get_config("QDRANT_HOST_NAME") or "localhost"
    configured_port = get_config("QDRANT_PORT") or 6333
    return QdrantClient(host=host, port=configured_port)
class Qdrant(VectorStore):
    """
    Qdrant vector store.

    Attributes:
        client : The Qdrant client.
        embedding_model : The embedding model.
        collection_name : The Qdrant collection.
        text_field_payload_key : Name of the field where the corresponding text for point is stored in the collection.
        metadata_payload_key : Name of the field where the corresponding metadata for point is stored in the collection.
    """
    TEXT_FIELD_KEY = "text"
    METADATA_KEY = "metadata"

    def __init__(
            self,
            client: QdrantClient,
            embedding_model: Optional[Any] = None,
            collection_name: str = None,
            text_field_payload_key: str = TEXT_FIELD_KEY,
            metadata_payload_key: str = METADATA_KEY,
    ):
        self.client = client
        self.embedding_model = embedding_model
        self.collection_name = collection_name
        self.text_field_payload_key = text_field_payload_key or self.TEXT_FIELD_KEY
        self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY

    def add_texts(
            self,
            input_texts: Iterable[str],
            metadata_list: Optional[List[dict]] = None,
            id_list: Optional[Sequence[str]] = None,
            batch_limit: int = 64,
    ) -> List[str]:
        """
        Add texts to the vector store.

        Args:
            input_texts : The texts to add.
            metadata_list : The metadatas to add.
            id_list : The ids to add.
            batch_limit : The batch size to add.

        Returns:
            The list of ids vectors stored in Qdrant.
        """
        # Materialise so generators can be measured and sliced.
        input_texts = list(input_texts)
        metadata_list = list(metadata_list) if metadata_list else []
        id_list = list(id_list) if id_list else [uuid.uuid4().hex for _ in input_texts]

        collected_ids = []
        for start in range(0, len(input_texts), batch_limit):
            end = start + batch_limit
            text_batch = input_texts[start:end]
            metadata_batch = metadata_list[start:end] or None
            id_batch = id_list[start:end]

            vectors = self.__get_embeddings(text_batch)
            payloads = self.__build_payloads(
                text_batch,
                metadata_batch,
                self.text_field_payload_key,
                self.metadata_payload_key,
            )
            # add_embeddings_to_vector_db reads the "payload" key; the previous
            # "payloads" key raised KeyError.
            self.add_embeddings_to_vector_db({"ids": id_batch, "vectors": vectors, "payload": payloads})
            collected_ids.extend(id_batch)

        return collected_ids

    def get_matching_text(
            self,
            text: str = None,
            embedding: List[float] = None,
            k: int = 4,
            metadata: Optional[dict] = None,
            search_params: Optional[common_types.SearchParams] = None,
            offset: int = 0,
            score_threshold: Optional[float] = None,
            consistency: Optional[common_types.ReadConsistency] = None,
            **kwargs: Any,
    ) -> Dict:
        """
        Return docs most similar to query using specified search type.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return.
            text : The text to search.
            metadata: Exact-match filters on payload keys. (Please refer https://qdrant.tech/documentation/concepts/filtering/)
            search_params: Additional search params
            offset: Offset of the first result to return.
            score_threshold: Define a minimal score threshold for the result.
            consistency: Read consistency of the search. Defines how many replicas
                should be queried before returning the result.
            **kwargs : The keyword arguments to search.

        Returns:
            A dict with "documents" (list of Document) and "search_res" (formatted text).

        Raises:
            ValueError: if both or neither of ``text`` and ``embedding`` are given.
        """
        if embedding is not None and text is not None:
            raise ValueError("Only provide embedding or text")
        if text is not None:
            # Wrap in a list: __get_embeddings iterates its argument, so a bare
            # string would be embedded character by character (and [0] would
            # return only the first character's embedding).
            embedding = self.__get_embeddings([text])[0]
        if embedding is None:
            raise ValueError("Provide either embedding or text")

        # Default to no filter; previously the name `filter` was unbound here and
        # resolved to the builtin function when metadata was None.
        query_filter = None
        if metadata is not None:
            query_filter = models.Filter(
                must=[{"key": key, "match": {"value": value}} for key, value in metadata.items()]
            )

        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            query_filter=query_filter,
            search_params=search_params,
            limit=k,
            offset=offset,
            with_payload=True,
            with_vectors=False,
            score_threshold=score_threshold,
            consistency=consistency,
            **kwargs,
        )

        search_res = self._get_search_res(results, text)
        documents = self.__build_documents(results)

        return {"documents": documents, "search_res": search_res}

    def get_index_stats(self) -> dict:
        """
        Returns:
            Stats or Information about a collection
        """
        collection_info = self.client.get_collection(collection_name=self.collection_name)
        dimensions = collection_info.config.params.vectors.size
        vector_count = collection_info.vectors_count

        return {"dimensions": dimensions, "vector_count": vector_count}

    def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
        """Upserts embeddings to the given vector store.

        Args:
            embeddings: Dict with parallel "ids", "vectors" and "payload" lists.
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=models.Batch(
                ids=embeddings["ids"],
                vectors=embeddings["vectors"],
                payloads=embeddings["payload"]
            ),
        )

    def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
        """Deletes embeddings from the given vector store"""
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.PointIdsList(
                points=ids
            ),
        )

    def __get_embeddings(
            self,
            texts: Iterable[str]
    ) -> List[List[float]]:
        """Return one embedding per element of ``texts`` using the embedding model."""
        if self.embedding_model is None:
            raise ValueError("Embedding model is not set")
        return [self.embedding_model.get_embedding(text) for text in texts]

    def __build_payloads(
            self,
            texts: Iterable[str],
            metadatas: Optional[List[dict]],
            text_field_payload_key: str,
            metadata_payload_key: str,
    ) -> List[dict]:
        """
        Builds and returns a list of payloads containing text and
        corresponding metadata for each text in the input iterable.
        Texts without a matching metadata entry get ``None`` metadata.
        """
        payloads = []
        for i, text in enumerate(texts):
            if text is None:
                raise ValueError(
                    "One or more of the text entries is set to None. "
                    "Ensure to eliminate these before invoking the .add_texts method on the Qdrant instance."
                )
            # Tolerate a metadata list shorter than the texts instead of IndexError.
            metadata = metadatas[i] if metadatas is not None and i < len(metadatas) else None
            payloads.append(
                {
                    text_field_payload_key: text,
                    metadata_payload_key: metadata,
                }
            )

        return payloads

    def __build_documents(
            self,
            results: List[Dict]
    ) -> List[Document]:
        """Return the document version corresponding to each result."""
        documents = []
        for result in results:
            documents.append(
                Document(
                    text_content=result.payload.get(self.text_field_payload_key),
                    metadata=(result.payload.get(self.metadata_payload_key)) or {},
                )
            )

        return documents

    @classmethod
    def create_collection(cls,
                          client: QdrantClient,
                          collection_name: str,
                          size: int
                          ):
        """
        Create a new collection in Qdrant if it does not exist.

        Args:
            client : The Qdrant client.
            collection_name: The name of the collection to create.
            size: The size for the new collection.
        """
        if not any(collection.name == collection_name for collection in client.get_collections().collections):
            client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=size, distance=Distance.COSINE),
            )

    def _get_search_res(self, results, text):
        """Format the query and each result's text payload as a single string."""
        search_res = f"Query: {text}\n"
        for position, res in enumerate(results):
            # Use the configured text key rather than a hard-coded "text".
            search_res += f"Chunk{position}: \n{res.payload.get(self.text_field_payload_key)}\n"
        return search_res
================================================
FILE: superagi/vector_store/redis.py
================================================
import json
import re
import uuid
from typing import Any, List, Iterable, Mapping
from typing import Optional, Pattern
import traceback
import numpy as np
import redis
from redis.commands.search.field import TagField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.vector_store.base import VectorStore
from superagi.vector_store.document import Document
DOC_PREFIX = "doc:"
CONTENT_KEY = "content"
METADATA_KEY = "metadata"
VECTOR_SCORE_KEY = "vector_score"


class Redis(VectorStore):
    """
    RediSearch-backed vector store. Each text is written as a Redis hash with
    the text under CONTENT_KEY, the float32 embedding bytes under
    ``self.vector_key`` and JSON-encoded metadata under METADATA_KEY.
    """

    def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
        pass

    def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
        pass

    def get_index_stats(self) -> dict:
        pass

    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"

    def __init__(self, index: Any, embedding_model: Any):
        """
        Args:
            index: Name of the RediSearch index (passed to ``redis_client.ft``).
            embedding_model: An instance of a BaseEmbedding model.
        """
        redis_url = get_config('REDIS_URL')
        self.redis_client = redis.Redis.from_url("redis://" + redis_url + "/0", decode_responses=True)
        self.index = index
        self.embedding_model = embedding_model
        # Previously a one-element tuple because of a stray trailing comma.
        self.content_key = "content"
        self.metadata_key = "metadata"
        self.vector_key = "content_vector"

    def build_redis_key(self, prefix: str) -> str:
        """Build a redis key with a prefix."""
        return f"{prefix}:{uuid.uuid4().hex}"

    def add_texts(self, texts: Iterable[str],
                  metadatas: Optional[List[dict]] = None,
                  embeddings: Optional[List[List[float]]] = None,
                  ids: Optional[list[str]] = None,
                  **kwargs: Any) -> List[str]:
        """Store each text as a hash (content, embedding bytes, JSON metadata) and return the keys.

        ``embeddings`` is accepted but not used; embeddings are always computed
        with ``self.embedding_model``.
        """
        pipe = self.redis_client.pipeline()
        prefix = DOC_PREFIX + str(self.index)
        keys = []
        for i, text in enumerate(texts):
            id = ids[i] if ids else self.build_redis_key(prefix)
            metadata = metadatas[i] if metadatas else {}
            embedding = self.embedding_model.get_embedding(text)
            embedding_arr = np.array(embedding, dtype=np.float32)

            pipe.hset(id, mapping={CONTENT_KEY: text, self.vector_key: embedding_arr.tobytes(),
                                   METADATA_KEY: json.dumps(metadata)})

            keys.append(id)
        pipe.execute()
        return keys

    def get_matching_text(self, query: str, top_k: int = 5, metadata: Optional[dict] = None, **kwargs: Any) -> List[Document]:
        """Run a KNN vector query (optionally pre-filtered by tag metadata).

        Returns:
            ``{"documents": [Document, ...]}``.
        """
        embed_text = self.embedding_model.get_embedding(query)
        from redis.commands.search.query import Query
        hybrid_fields = self._convert_to_redis_filters(metadata)

        base_query = f"{hybrid_fields}=>[KNN {top_k} @{self.vector_key} $vector AS vector_score]"
        return_fields = [METADATA_KEY, CONTENT_KEY, "vector_score", 'id']
        query = (
            Query(base_query)
            .return_fields(*return_fields)
            .sort_by("vector_score")
            .paging(0, top_k)
            .dialect(2)
        )

        params_dict: Mapping[str, str] = {
            "vector": np.array(embed_text)
            .astype(dtype=np.float32)
            .tobytes()
        }

        results = self.redis_client.ft(self.index).search(query, params_dict)

        documents = []
        for result in results.docs:
            documents.append(
                Document(
                    text_content=result.content,
                    metadata=json.loads(result.metadata)
                )
            )
        return {"documents": documents}

    def _convert_to_redis_filters(self, metadata: Optional[dict] = None) -> str:
        """Build the RediSearch pre-filter for a KNN query.

        Returns "*" (match everything) when ``metadata`` is None or empty;
        otherwise an intersection of ``@key:{value}`` tag clauses.
        NOTE(review): filter keys must be TAG fields in the index schema;
        create_index currently only declares "tag".
        """
        # The condition was inverted: None crashed on len(None) and any
        # non-empty filter was silently ignored.
        if not metadata:
            return "*"
        filter_strings = [
            "@%s:{%s}" % (key, self.escape_token(str(value)))
            for key, value in metadata.items()
        ]
        # RediSearch intersects space-separated clauses; "&" is not an operator.
        return "(" + " ".join(filter_strings) + ")"

    def create_index(self):
        """Create the vector index unless ``ft(index).info()`` shows it already exists."""
        try:
            # info() raises when the index does not exist yet.
            index_info = self.redis_client.ft(self.index).info()
            logger.info(index_info)
            logger.info("Index already exists!")
        except Exception:
            # Embed a sample string only to learn the vector dimension.
            vector_dimensions = self.embedding_model.get_embedding("sample")
            schema = (
                TagField("tag"),
                VectorField(self.vector_key,
                            "FLAT", {
                                "TYPE": "FLOAT32",
                                "DIM": len(vector_dimensions),
                                "DISTANCE_METRIC": "COSINE",
                            }
                            )
            )
            # Index every hash whose key starts with DOC_PREFIX.
            definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.HASH)
            self.redis_client.ft(self.index).create_index(fields=schema, definition=definition)

    def escape_token(self, value: str) -> str:
        """
        Escape punctuation within an input string. Taken from RedisOM Python.

        Args:
            value (str): The input string.

        Returns:
            str: The escaped string.
        """
        escaped_chars_re = re.compile(Redis.DEFAULT_ESCAPED_CHARS)

        def escape_symbol(match: re.Match) -> str:
            return f"\\{match.group(0)}"

        return escaped_chars_re.sub(escape_symbol, value)
================================================
FILE: superagi/vector_store/vector_factory.py
================================================
import pinecone
from pinecone import UnauthorizedException
from superagi.vector_store.pinecone import Pinecone
from superagi.vector_store import weaviate
from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.types.vector_store_types import VectorStoreType
from superagi.vector_store import qdrant
from superagi.vector_store.redis import Redis
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.qdrant import Qdrant
class VectorFactory:

    @staticmethod
    def _sample_embedding_dimension(embedding_model):
        """
        Embed a sample string and return the embedding's length.

        Returns None (after logging) when the embedding model reports an error,
        i.e. returns something containing "error". A new index must never be
        sized from the length of that error dict.
        """
        sample_embedding = embedding_model.get_embedding("sample")
        if "error" in sample_embedding:
            logger.error(f"Error in embedding model {sample_embedding}")
            return None
        return len(sample_embedding)

    @classmethod
    def get_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding_model):
        """
        Get the vector storage.

        Args:
            vector_store : The vector store name.
            index_name : The index name.
            embedding_model : The embedding model.

        Returns:
            The vector storage object.

        Raises:
            ValueError: on missing credentials, an unusable embedding model when
                a new Pinecone index is needed, or an unsupported store.
        """
        if isinstance(vector_store, str):
            vector_store = VectorStoreType.get_vector_store_type(vector_store)

        if vector_store == VectorStoreType.PINECONE:
            try:
                api_key = get_config("PINECONE_API_KEY")
                env = get_config("PINECONE_ENVIRONMENT")
                if api_key is None or env is None:
                    raise ValueError("PineCone API key not found")
                pinecone.init(api_key=api_key, environment=env)

                if index_name not in pinecone.list_indexes():
                    dimension = cls._sample_embedding_dimension(embedding_model)
                    if dimension is None:
                        raise ValueError(f"Cannot create Pinecone index {index_name}: embedding model failed")
                    pinecone.create_index(
                        index_name,
                        dimension=dimension,
                        metric='dotproduct'
                    )
                index = pinecone.Index(index_name)
                return Pinecone(index, embedding_model, 'text')
            except UnauthorizedException:
                raise ValueError("PineCone API key not found")

        if vector_store == VectorStoreType.WEAVIATE:
            url = get_config("WEAVIATE_URL")
            api_key = get_config("WEAVIATE_API_KEY")
            # create_weaviate_client only accepts url and api_key; passing the
            # former use_embedded keyword raised TypeError.
            client = weaviate.create_weaviate_client(url=url, api_key=api_key)
            return weaviate.Weaviate(client, embedding_model, index_name, 'text')

        if vector_store == VectorStoreType.QDRANT:
            client = qdrant.create_qdrant_client()
            dimension = cls._sample_embedding_dimension(embedding_model)
            # Best effort: an existing collection is still usable, but never
            # create one sized from a failed embedding.
            if dimension is not None:
                Qdrant.create_collection(client, index_name, dimension)
            return qdrant.Qdrant(client, embedding_model, index_name)

        if vector_store == VectorStoreType.REDIS:
            index_name = "super-agent-index1"
            redis = Redis(index_name, embedding_model)
            redis.create_index()
            return redis

        raise ValueError(f"Vector store {vector_store} not supported")

    @classmethod
    def build_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding_model=None, **creds):
        """
        Build a vector store from explicit credentials.

        Args:
            vector_store : The vector store (enum member or name).
            index_name : The index / collection / class name.
            embedding_model : Optional embedding model.
            **creds : "api_key" plus "environment" (Pinecone), "url" and "port"
                (Qdrant) or "url" (Weaviate).

        Returns:
            The vector storage object, or None for store types not handled here.
        """
        if isinstance(vector_store, str):
            vector_store = VectorStoreType.get_vector_store_type(vector_store)

        if vector_store == VectorStoreType.PINECONE:
            try:
                pinecone.init(api_key=creds["api_key"], environment=creds["environment"])
                index = pinecone.Index(index_name)
                return Pinecone(index, embedding_model)
            except UnauthorizedException:
                raise ValueError("PineCone API key not found")

        if vector_store == VectorStoreType.QDRANT:
            try:
                client = qdrant.create_qdrant_client(creds["api_key"], creds["url"], creds["port"])
                return qdrant.Qdrant(client, embedding_model, index_name)
            except Exception as err:
                # Keep the original cause instead of masking it.
                raise ValueError("Qdrant API key not found") from err

        if vector_store == VectorStoreType.WEAVIATE:
            try:
                client = weaviate.create_weaviate_client(creds["url"], creds["api_key"])
                return weaviate.Weaviate(client, embedding_model, index_name)
            except Exception as err:
                raise ValueError("Weaviate API key not found") from err
================================================
FILE: superagi/vector_store/weaviate.py
================================================
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Tuple
import weaviate
from uuid import uuid4
from superagi.vector_store.base import VectorStore
from superagi.vector_store.document import Document
def create_weaviate_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    use_embedded: bool = False,
) -> weaviate.Client:
    """
    Creates a Weaviate client instance.
    Args:
        url: The URL of the Weaviate instance to connect to. When given it takes
            precedence over ``use_embedded``.
        api_key: The API key to use for authentication (e.g. Weaviate Cloud
            Services). Only sent when non-empty. Optional.
        use_embedded: When no ``url`` is given, start/connect to an embedded
            Weaviate instance instead. Defaults to False. It is the last
            parameter so existing positional ``(url, api_key)`` calls keep
            working, and callers passing ``use_embedded=...`` no longer fail
            with TypeError.
    Returns:
        A Weaviate client instance.
    Raises:
        ValueError: If neither ``url`` nor ``use_embedded`` is provided.
    """
    if url:
        auth_config = weaviate.AuthApiKey(api_key=api_key) if api_key else None
        return weaviate.Client(url=url, auth_client_secret=auth_config)
    if use_embedded:
        # Imported lazily so the embedded extra is only needed when used.
        from weaviate.embedded import EmbeddedOptions
        return weaviate.Client(embedded_options=EmbeddedOptions())
    raise ValueError("Invalid arguments passed to create_weaviate_client")
class Weaviate(VectorStore):
    """Vector store backed by one Weaviate class.

    Each stored object holds its text under ``text_field`` plus any metadata
    keys as extra properties. Vectors come from ``embedding_model`` and are
    passed to Weaviate explicitly on insert and on query.
    """

    def __init__(
        self, client: weaviate.Client, embedding_model: Any, class_name: str, text_field: str = "text"
    ):
        self.class_name = class_name
        self.embedding_model = embedding_model
        self.text_field = text_field
        self.client = client

    def add_texts(
        self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any
    ) -> List[str]:
        """Embed and store each text; return the generated UUIDs in input order.

        Args:
            texts: Texts to store.
            metadatas: Optional per-text metadata dicts, aligned with ``texts``
                by index. They are copied, not mutated.
        """
        collected_ids = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            data_object = metadata.copy()
            data_object[self.text_field] = text
            vector = self.embedding_model.get_embedding(text)
            object_id = str(uuid4())
            collected_ids.append(object_id)
            # add_embeddings_to_vector_db expects parallel lists. Passing bare
            # values made it iterate over the UUID's characters and index the
            # data dict with integers.
            self.add_embeddings_to_vector_db(
                {"ids": [object_id], "data_object": [data_object], "vectors": [vector]}
            )
        return collected_ids

    def get_matching_text(
        self, query: str, top_k: int = 5, metadata: Optional[dict] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Return up to ``top_k`` objects near ``query`` (certainty >= 0.7).

        Args:
            query: Text to embed and search with.
            top_k: Maximum number of results.
            metadata: Optional ``{property: value}`` equality filters. All of
                them must match. None or empty means no filtering.

        Returns:
            ``{"search_res": <formatted string>, "documents": [Document, ...]}``.
        """
        metadata_fields = self._get_metadata_fields()
        query_vector = self.embedding_model.get_embedding(query)
        query_builder = self.client.query.get(
            self.class_name,
            metadata_fields + [self.text_field],
        ).with_near_vector(
            {"vector": query_vector, "certainty": 0.7}
        )
        where_filter = self._build_where_filter(metadata)
        if where_filter is not None:
            query_builder = query_builder.with_where(where_filter)
        results = query_builder.with_limit(top_k).do()
        results_data = results["data"]["Get"][self.class_name]
        search_res = self._get_search_res(results_data, query)
        documents = self._build_documents(results_data, metadata_fields)
        return {"search_res": search_res, "documents": documents}

    @staticmethod
    def _build_where_filter(metadata: Optional[dict]) -> Optional[dict]:
        """Translate ``{key: value}`` into a Weaviate where filter (None if empty)."""
        if not metadata:
            return None
        operands = [
            {"path": [key], "operator": "Equal", "valueString": value}
            for key, value in metadata.items()
        ]
        if len(operands) == 1:
            return operands[0]
        # Every key must match; previously only the last key was applied.
        return {"operator": "And", "operands": operands}

    def _get_metadata_fields(self) -> List[str]:
        """Names of the class's schema properties other than the text field."""
        schema = self.client.schema.get(self.class_name)
        # Filter rather than list.remove() so a schema lacking the text field
        # does not raise ValueError.
        return [
            property_schema["name"]
            for property_schema in schema["properties"]
            if property_schema["name"] != self.text_field
        ]

    def get_index_stats(self) -> dict:
        """Return ``{'vector_count': <object count of this class>}``."""
        result = self.client.query.aggregate(self.class_name).with_meta_count().do()
        vector_count = result['data']['Aggregate'][self.class_name][0]['meta']['count']
        return {'vector_count': vector_count}

    def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
        """Batch-insert objects.

        Args:
            embeddings: Dict of parallel lists ``ids``, ``data_object`` and
                ``vectors``. ``ids`` determines how many objects are written.
        """
        with self.client.batch as batch:
            for i in range(len(embeddings['ids'])):
                data_object = dict(embeddings['data_object'][i])
                batch.add_data_object(data_object, class_name=self.class_name, uuid=embeddings['ids'][i], vector=embeddings['vectors'][i])

    def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
        """Delete each object in ``ids`` from this class, one request per id."""
        for object_id in ids:
            self.client.data_object.delete(
                uuid=object_id,
                class_name=self.class_name
            )

    def _build_documents(self, results_data, metadata_fields) -> List[Document]:
        """Convert raw result rows into ``Document`` objects."""
        return [
            Document(
                text_content=row[self.text_field],
                metadata={field: row[field] for field in metadata_fields},
            )
            for row in results_data
        ]

    def _get_search_res(self, results, query):
        """Format results as ``Query: ...`` followed by numbered ``ChunkN`` sections."""
        search_res = f"Query: {query}\n"
        for i, item in enumerate(results):
            # Use the configured text field, not a hard-coded 'text' key.
            search_res += f"Chunk{i}: \n{item[self.text_field]}\n"
        return search_res
================================================
FILE: superagi/worker.py
================================================
from __future__ import absolute_import
import sys
from sqlalchemy.orm import sessionmaker
from superagi.helper.tool_helper import handle_tools_import
from superagi.lib.logger import logger
from datetime import timedelta
from celery import Celery
from superagi.config.config import get_config
from superagi.helper.agent_schedule_helper import AgentScheduleHelper
from superagi.models.configuration import Configuration
from superagi.models.agent import Agent
from superagi.models.db import connect_db
from superagi.types.model_source_types import ModelSourceType
from sqlalchemy import event
from superagi.models.agent_execution import AgentExecution
from superagi.helper.webhook_manager import WebHookManager
# Celery application. Redis database 0 at REDIS_URL (default
# "super__redis:6379") serves as both the message broker and result backend.
redis_url = get_config('REDIS_URL', 'super__redis:6379')
app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"])
app.conf.broker_url = "redis://" + redis_url + "/0"
app.conf.result_backend = "redis://" + redis_url + "/0"
app.conf.worker_concurrency = 10
# Accept pickle as well as JSON: some tasks below declare serializer='pickle'.
app.conf.accept_content = ['application/x-python-serialize', 'application/json']

# Periodic jobs for celery beat. Each entry is keyed by its own task name.
beat_schedule = {
    task_name: {'task': task_name, 'schedule': timedelta(minutes=every_minutes)}
    for task_name, every_minutes in (
        ('initialize-schedule-agent', 5),
        ('execute_waiting_workflows', 2),
    )
}
app.conf.beat_schedule = beat_schedule
@event.listens_for(AgentExecution.status, "set")
def agent_status_change(target, val,old_val,initiator):
    """SQLAlchemy ``set`` listener on ``AgentExecution.status``.

    Enqueues ``webhook_callback`` with the execution id and the new and old
    status. Nothing is queued when ``sys._called_from_test`` exists
    (presumably set by the test harness — confirm).
    """
    if hasattr(sys, '_called_from_test'):
        return
    webhook_callback.delay(target.id, val, old_val)
@app.task(name="execute_waiting_workflows", autoretry_for=(Exception,), retry_backoff=2, max_retries=5)
def execute_waiting_workflows():
    """Periodic job: let AgentExecutor resume workflows whose wait step has elapsed."""
    # Imported inside the task, like the other executor-based tasks in this module.
    from superagi.jobs.agent_executor import AgentExecutor
    logger.info("Executing waiting workflows job")
    executor = AgentExecutor()
    executor.execute_waiting_workflows()
@app.task(name="initialize-schedule-agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5)
def initialize_schedule_agent_task():
    """Periodic job: refresh next scheduled run times, then run scheduled agents."""
    helper = AgentScheduleHelper()
    helper.update_next_scheduled_time()
    helper.run_scheduled_agents()
@app.task(name="execute_agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5)
def execute_agent(agent_execution_id: int, time):
    """Run the next step of an agent execution on a worker.

    Args:
        agent_execution_id: Id of the AgentExecution to advance.
        time: Enqueue timestamp; used only in the log line.
    """
    from superagi.jobs.agent_executor import AgentExecutor
    handle_tools_import()
    logger.info(f"Execute agent:{time!s},{agent_execution_id!s}")
    executor = AgentExecutor()
    executor.execute_next_step(agent_execution_id=agent_execution_id)
@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle')
def summarize_resource(agent_id: int, resource_id: int):
    """Summarize a resource in background.

    Loads the resource's documents (via S3 or local storage, depending on the
    resource's storage type). It then hands them to ``ResourceSummarizer`` to be
    added to the vector store and summarized. Nothing is done when the agent's
    model source contains the Google PaLM or Replicate source value.

    Args:
        agent_id: Agent owning the resource.
        resource_id: Id of the Resource row to summarize.
    """
    from superagi.resource_manager.resource_summary import ResourceSummarizer
    from superagi.types.storage_types import StorageType
    from superagi.models.resource import Resource
    from superagi.resource_manager.resource_manager import ResourceManager
    engine = connect_db()
    Session = sessionmaker(bind=engine)
    # The context manager closes the session on every exit path: the early
    # return below and exceptions that trigger Celery's autoretry. Previously
    # it was only closed after a fully successful run.
    with Session() as session:
        agent_config = Agent.fetch_configuration(session, agent_id)
        organisation = Agent.find_org_by_agent_id(session, agent_id)
        model_source = Configuration.fetch_configurations(session, organisation.id, "model_source", agent_config["model"]) or "OpenAi"
        if ModelSourceType.GooglePalm.value in model_source or ModelSourceType.Replicate.value in model_source:
            return
        resource = session.query(Resource).filter(Resource.id == resource_id).first()
        file_path = resource.path
        resource_manager = ResourceManager(str(agent_id))
        if resource.storage_type == StorageType.S3.value:
            documents = resource_manager.create_llama_document_s3(file_path)
        else:
            documents = resource_manager.create_llama_document(file_path)
        logger.info("Summarize resource:" + str(agent_id) + "," + str(resource_id))
        resource_summarizer = ResourceSummarizer(session=session, agent_id=agent_id, model=agent_config["model"])
        resource_summarizer.add_to_vector_store_and_create_summary(resource_id=resource_id,
                                                                   documents=documents)
@app.task(name="webhook_callback", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle')
def webhook_callback(agent_execution_id,val,old_val):
    """Forward a status change to ``WebHookManager.agent_status_change_callback``.

    A fresh session is opened for the call and closed afterwards.
    """
    session_factory = sessionmaker(bind=connect_db())
    with session_factory() as session:
        manager = WebHookManager(session)
        manager.agent_status_change_callback(agent_execution_id, val, old_val)
================================================
FILE: test.py
================================================
import argparse
from datetime import datetime
from time import time
from superagi.lib.logger import logger
from sqlalchemy.orm import sessionmaker
from superagi.worker import execute_agent
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.db import connect_db
from superagi.models.organisation import Organisation
from superagi.models.project import Project
# Command-line options. Anything omitted is asked for interactively by
# run_superagi_cli.
parser = argparse.ArgumentParser(description='Create a new agent.')
parser.add_argument('--name', type=str, help='Agent name for the script.')
parser.add_argument('--description', type=str, help='Agent description for the script.')
parser.add_argument('--goals', type=str, nargs='+', help='Agent goals for the script.')
args = parser.parse_args()
agent_name, agent_description, agent_goals = args.name, args.description, args.goals

# Module-wide database session used by run_superagi_cli.
engine = connect_db()
Session = sessionmaker(bind=engine)
session = Session()
def ask_user_for_goals():
    """Prompt for goals one per line until the user types 'q'.

    Returns:
        The entered goals in order. The terminating 'q' is not included.
    """
    # iter(callable, sentinel) keeps calling input() until it returns 'q'.
    return list(iter(lambda: input("Enter a goal (or 'q' to quit): "), 'q'))
def run_superagi_cli(agent_name=None, agent_description=None, agent_goals=None):
    """Bootstrap and start an agent run from the command line.

    Creates a new default organisation and project, an agent, its configuration
    rows and a RUNNING execution. It then enqueues ``execute_agent`` for that
    execution. Missing name, description or goals are read from stdin.

    Args:
        agent_name: Agent name. Read with ``input`` when None.
        agent_description: Agent description. Read with ``input`` when None.
        agent_goals: List of goals. Gathered with ``ask_user_for_goals`` when None.
    """

    def save(record):
        # Add, flush and commit one ORM object, then log it.
        session.add(record)
        session.flush()
        session.commit()
        logger.info(record)
        return record

    # A new default organisation and project are created on every run.
    organization = save(Organisation(name='Default Organization',
                                     description='Default organization description'))
    project = save(Project(name='Default Project', description='Default project description',
                           organisation_id=organization.id))

    if agent_name is None:
        agent_name = input("Enter agent name: ")
    if agent_description is None:
        agent_description = input("Enter agent description: ")
    agent = save(Agent(name=agent_name, description=agent_description, project_id=project.id))

    goals = agent_goals if agent_goals is not None else ask_user_for_goals()
    constraints = [
        "~4000 word limit for short term memory. ",
        "Your short term memory is short, so immediately save important information to files.",
        "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
        "No user assistance",
        "Exclusively use the commands listed in double quotes e.g. \"command name\"",
    ]
    # One AgentConfiguration row per key. Values are stored as str(value).
    agent_config_values = {
        "goal": goals,
        "agent_type": "Type Non-Queue",
        "constraints": constraints,
        "tools": [],
        "exit": "Default",
        "iteration_interval": 0,
        "model": "gpt-4",
        "permission_type": "Default",
        "LTM_DB": "Pinecone",
        "memory_window": 10,
    }
    agent_configurations = [
        AgentConfiguration(agent_id=agent.id, key=config_key, value=str(config_value))
        for config_key, config_value in agent_config_values.items()
    ]
    session.add_all(agent_configurations)
    session.commit()
    logger.info("Agent Config : ")
    logger.info(agent_configurations)

    execution = AgentExecution(status='RUNNING', agent_id=agent.id, last_execution_time=datetime.utcnow())
    session.add(execution)
    session.commit()
    logger.info("Final Execution")
    logger.info(execution)
    execute_agent.delay(execution.id, datetime.now())
# Script entry point. It is not behind a __main__ guard, so it runs whenever this
# module is executed or imported, using the CLI values parsed above.
run_superagi_cli(agent_name=agent_name, agent_description=agent_description, agent_goals=agent_goals)
================================================
FILE: test_main.http
================================================
# Test your FastAPI endpoints
GET http://127.0.0.1:8000/
Accept: application/json
###
GET http://127.0.0.1:8000/hello/User
Accept: application/json
###
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/integration_tests/__init__.py
================================================
================================================
FILE: tests/integration_tests/vector_embeddings/__init__.py
================================================
================================================
FILE: tests/integration_tests/vector_embeddings/test_pinecone.py
================================================
import unittest
from superagi.vector_embeddings.pinecone import Pinecone
class TestPinecone(unittest.TestCase):
    """Tests for the Pinecone vector-embedding container."""

    def setUp(self):
        self.uuid, self.embeds, self.metadata = (
            ["id1", "id2"],
            ["embed1", "embed2"],
            ["metadata1", "metadata2"],
        )
        self.pinecone_instance = Pinecone(self.uuid, self.embeds, self.metadata)

    def test_init(self):
        for attr in ("uuid", "embeds", "metadata"):
            with self.subTest(attr=attr):
                self.assertEqual(getattr(self.pinecone_instance, attr), getattr(self, attr))

    def test_get_vector_embeddings_from_chunks(self):
        # Pinecone's upsert format: (id, vector, metadata) tuples under 'vectors'.
        expected = {'vectors': list(zip(self.uuid, self.embeds, self.metadata))}
        self.assertEqual(self.pinecone_instance.get_vector_embeddings_from_chunks(), expected)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/integration_tests/vector_embeddings/test_qdrant.py
================================================
import unittest
from superagi.vector_embeddings.qdrant import Qdrant
class TestQdrant(unittest.TestCase):
    """Tests for the Qdrant vector-embedding container."""

    def setUp(self):
        self.uuid = ['1234', '5678']
        self.embeds = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        self.metadata = [{'key1': 'value1'}, {'key2': 'value2'}]
        self.qdrant_obj = Qdrant(self.uuid, self.embeds, self.metadata)

    def test_init(self):
        for attr in ('uuid', 'embeds', 'metadata'):
            with self.subTest(attr=attr):
                self.assertEqual(getattr(self.qdrant_obj, attr), getattr(self, attr))

    def test_get_vector_embeddings_from_chunks(self):
        expected = dict(ids=self.uuid, payload=self.metadata, vectors=self.embeds)
        self.assertEqual(self.qdrant_obj.get_vector_embeddings_from_chunks(), expected)
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/integration_tests/vector_embeddings/test_weaviate.py
================================================
import unittest
from superagi.vector_embeddings.base import VectorEmbeddings
from superagi.vector_embeddings.weaviate import Weaviate
class TestWeaviate(unittest.TestCase):
    """Tests for the Weaviate vector-embedding container."""

    def setUp(self):
        self.weaviate = Weaviate(
            uuid="1234",
            embeds=[0.1, 0.2, 0.3, 0.4],
            metadata={"info": "sample data"},
        )

    def test_init(self):
        expected_attrs = {
            "uuid": "1234",
            "embeds": [0.1, 0.2, 0.3, 0.4],
            "metadata": {"info": "sample data"},
        }
        for attr, value in expected_attrs.items():
            with self.subTest(attr=attr):
                self.assertEqual(getattr(self.weaviate, attr), value)

    def test_get_vector_embeddings_from_chunks(self):
        # Metadata is exposed under 'data_object', the name Weaviate's batch API uses.
        self.assertEqual(
            self.weaviate.get_vector_embeddings_from_chunks(),
            {
                "ids": "1234",
                "data_object": {"info": "sample data"},
                "vectors": [0.1, 0.2, 0.3, 0.4],
            },
        )
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/integration_tests/vector_store/__init__.py
================================================
================================================
FILE: tests/integration_tests/vector_store/test_qdrant.py
================================================
import pytest
import numpy as np
from superagi.vector_store import qdrant
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from qdrant_client.models import Distance, VectorParams
from qdrant_client import QdrantClient
@pytest.fixture
def client():
    """In-memory Qdrant client. It needs no teardown."""
    yield QdrantClient(":memory:")
@pytest.fixture
def mock_openai_embedding(monkeypatch):
    """Stub OpenAiEmbedding.get_embedding with random 3-dimensional vectors (no API calls)."""
    def fake_get_embedding(self, text):
        return np.random.random(3).tolist()

    monkeypatch.setattr(OpenAiEmbedding, "get_embedding", fake_get_embedding)
@pytest.fixture
def store(client, mock_openai_embedding):
    """Qdrant store over a fresh 3-dim cosine collection, dropped after the test."""
    collection = "Test_collection"
    client.create_collection(
        collection_name=collection,
        vectors_config=VectorParams(size=3, distance=Distance.COSINE),
    )
    yield qdrant.Qdrant(client, OpenAiEmbedding(api_key="test_api_key"), collection)
    client.delete_collection(collection)
def test_add_texts(store):
    """add_texts returns one id per stored text."""
    car_companies = [
        "Rolls-Royce", "Bentley", "Ferrari", "Lamborghini", "Aston Martin",
        "Porsche", "Bugatti", "Maserati", "McLaren", "Mercedes-Benz",
    ]
    ids = store.add_texts(car_companies)
    assert len(ids) == len(car_companies)
def test_get_matching_text(store):
    """A top-2 query over ten stored names yields two results."""
    # NOTE(review): this calls get_matching_text(k=..., text=...). Confirm those
    # keywords match Qdrant.get_matching_text. The Weaviate store in this repo
    # takes (query, top_k, metadata) and returns a two-key dict, so a len()
    # check of 2 there would pass trivially.
    car_companies = [
        "Rolls-Royce", "Bentley", "Ferrari", "Lamborghini", "Aston Martin",
        "Porsche", "Bugatti", "Maserati", "McLaren", "Mercedes-Benz",
    ]
    store.add_texts(car_companies)
    matches = store.get_matching_text(k=2, text="McLaren")
    assert len(matches) == 2
================================================
FILE: tests/integration_tests/vector_store/test_weaviate.py
================================================
import unittest
from unittest.mock import Mock, patch, call, MagicMock
from superagi.vector_store.weaviate import create_weaviate_client, Weaviate, Document
class TestWeaviateClient(unittest.TestCase):
    """Tests for create_weaviate_client."""

    @patch('weaviate.Client')
    @patch('weaviate.AuthApiKey')
    def test_create_weaviate_client(self, MockAuth, MockClient):
        MockClient.return_value = 'client'

        # URL plus API key: the client is built with API-key auth.
        client = create_weaviate_client('url', 'api_key')
        self.assertEqual(client, 'client')
        MockAuth.assert_called_once_with(api_key='api_key')
        MockClient.assert_called_once_with(url='url', auth_client_secret=MockAuth.return_value)

        # No URL: there is nothing to connect to.
        with self.assertRaises(ValueError):
            create_weaviate_client()
class TestWeaviate(unittest.TestCase):
    """Tests for the Weaviate vector store using a mocked weaviate client."""
    def setUp(self):
        # client.batch is used as a context manager (`with self.client.batch as batch`),
        # so a MagicMock whose __enter__ returns itself records add_data_object calls.
        mock_batch = MagicMock()
        mock_batch.__enter__.return_value = mock_batch
        mock_batch.__exit__.return_value = None
        self.client = Mock()
        self.client.batch = mock_batch
        self.embedding_model = Mock()
        self.weaviateVectorStore = Weaviate(self.client, self.embedding_model, 'class_name', 'text_field')
    def test_get_matching_text(self):
        # Stub the full query-builder chain get -> near_vector -> where -> limit -> do
        # to return an empty result set for 'class_name'.
        self.client.query.get.return_value.with_near_vector.return_value.with_where.return_value.with_limit.return_value.do.return_value = {'data': {'Get': {'class_name': []}}}
        self.embedding_model.get_embedding.return_value = 'vector'
        # Private helpers are stubbed so only the assembly of the result is checked.
        self.weaviateVectorStore._get_metadata_fields = Mock(return_value=['field1', 'field2'])
        self.weaviateVectorStore._get_search_res = Mock(return_value='search_res')
        self.weaviateVectorStore._build_documents = Mock(return_value=['document1', 'document2'])
        self.assertEqual(self.weaviateVectorStore.get_matching_text('query', metadata={'field1': 'value'})
        , {'search_res': 'search_res', 'documents': ['document1', 'document2']})
        self.embedding_model.get_embedding.assert_called_once_with('query')
    def test_add_texts(self):
        self.embedding_model.get_embedding.return_value = 'vector'
        # Stubbed out: only the per-text embedding and insert calls are counted here.
        self.weaviateVectorStore.add_embeddings_to_vector_db = Mock()
        texts = ['text1', 'text2']
        result = self.weaviateVectorStore.add_texts(texts)
        self.assertEqual(len(result), 2) # We expect to get 2 IDs.
        self.assertTrue(isinstance(result[0], str)) # The IDs should be strings.
        self.embedding_model.get_embedding.assert_has_calls([call(texts[0]), call(texts[1])])
        self.assertEqual(self.weaviateVectorStore.add_embeddings_to_vector_db.call_count, 2)
    def test_add_embeddings_to_vector_db(self):
        # Parallel lists: entry i of each list describes one object.
        embeddings = {'ids': ['id1', 'id2'], 'data_object': [{'field': 'value1'}, {'field': 'value2'}], 'vectors': ['v1', 'v2']}
        self.weaviateVectorStore.add_embeddings_to_vector_db(embeddings)
        calls = [call.add_data_object({'field': 'value1'}, class_name='class_name', uuid='id1', vector='v1'),
        call.add_data_object({'field': 'value2'}, class_name='class_name', uuid='id2', vector='v2')]
        self.client.batch.assert_has_calls(calls)
    def test_delete_embeddings_from_vector_db(self):
        # Only checks that a delete was issued. NOTE(review): could also assert one
        # call per id with uuid=/class_name= keywords.
        self.weaviateVectorStore.delete_embeddings_from_vector_db(['id1', 'id2'])
        self.client.data_object.delete.assert_called()
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/tools/google_calendar/create_event_test.py
================================================
import unittest
from unittest.mock import MagicMock, patch
from pydantic import ValidationError
from datetime import datetime, timedelta
from superagi.tools.google_calendar.create_calendar_event import CreateEventCalendarInput, CreateEventCalendarTool
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
from superagi.helper.calendar_date import CalendarDate
class TestCreateEventCalendarInput(unittest.TestCase):
    """Validation tests for CreateEventCalendarInput."""

    @staticmethod
    def _payload(start_date, start_time, end_date, end_time):
        # Fixed event fields plus the date/time values under test.
        return {
            "event_name": "Test Event",
            "description": "A test event.",
            "start_date": start_date,
            "start_time": start_time,
            "end_date": end_date,
            "end_time": end_time,
            "attendees": ["test@example.com"],
            "location": "London",
        }

    def test_create_event_calendar_input_valid(self):
        payload = self._payload("2022-01-01", "12:00:00", "2022-01-01", "13:00:00")
        try:
            CreateEventCalendarInput(**payload)
        except ValidationError:
            self.fail("ValidationError raised with valid input_data")

    def test_create_event_calendar_input_invalid(self):
        payload = self._payload("2022-99-99", "12:60:60", "2022-99-99", "13:60:60")
        with self.assertRaises(ValidationError):
            CreateEventCalendarInput(**payload)
class TestCreateEventCalendarTool(unittest.TestCase):
    """Tests for CreateEventCalendarTool._execute with mocked credentials and dates."""

    def setUp(self):
        self.create_event_tool = CreateEventCalendarTool()

    @patch.object(GoogleCalendarCreds, "get_credentials")
    @patch.object(CalendarDate, "create_event_dates")
    def test_execute(self, mock_create_event_dates, mock_get_credentials):
        # The service handed to the tool by get_credentials. Every assertion about
        # Calendar API calls must target this same object. A separately created
        # MagicMock is never seen by the tool, so those assertions could never pass.
        mock_service = MagicMock()
        mock_get_credentials.return_value = {
            "success": True,
            "service": mock_service
        }
        # NOTE(review): assumes the tool reads the event link from the inserted
        # event's "htmlLink" key. Confirm against create_calendar_event.py.
        mock_service.events.return_value.insert.return_value.execute.return_value = {
            "htmlLink": "https://somerandomlink"
        }
        mock_date_utc = {
            "start_datetime_utc": (datetime.utcnow() + timedelta(hours=1)).isoformat(),
            "end_datetime_utc": (datetime.utcnow() + timedelta(hours=2)).isoformat(),
            "timeZone": "UTC"
        }
        mock_create_event_dates.return_value = mock_date_utc
        output_str_expected = f"Event Test Event at {mock_date_utc['start_datetime_utc']} created successfully, link for the event {'https://somerandomlink'}"
        output_str = self.create_event_tool._execute("Test Event", "A test event", ["test@example.com"], start_date="2022-01-01", start_time="12:00:00", end_date="2022-01-01", end_time="13:00:00", location="London")
        self.assertEqual(output_str, output_str_expected)
        # NOTE(review): if the tool also sends conferenceData (conferenceDataVersion=1
        # suggests it may), this expected body must include it too.
        event = {
            "summary": "Test Event",
            "description": "A test event",
            "start": {
                "dateTime": mock_date_utc["start_datetime_utc"],
                "timeZone": mock_date_utc["timeZone"]
            },
            "end": {
                "dateTime": mock_date_utc["end_datetime_utc"],
                "timeZone": mock_date_utc["timeZone"]
            },
            "attendees": [{"email": "test@example.com"}],
            "location": "London"
        }
        mock_get_credentials.assert_called_once()
        mock_create_event_dates.assert_called_once_with(mock_service, "2022-01-01", "12:00:00", "2022-01-01", "13:00:00")
        mock_service.events.return_value.insert.assert_called_once_with(calendarId="primary", body=event, conferenceDataVersion=1)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/tools/google_calendar/delete_event_test.py
================================================
import unittest
from unittest.mock import Mock, patch
from pydantic import ValidationError
from superagi.tools.google_calendar.delete_calendar_event import DeleteCalendarEventInput, DeleteCalendarEventTool
class TestDeleteCalendarEventInput(unittest.TestCase):
    """Validation tests for DeleteCalendarEventInput."""

    def test_valid_input(self):
        parsed = DeleteCalendarEventInput(event_id="123456")
        self.assertEqual(parsed.event_id, "123456")

    def test_invalid_input(self):
        # An empty event id must be rejected.
        with self.assertRaises(ValidationError):
            DeleteCalendarEventInput(event_id="")
class TestDeleteCalendarEventTools(unittest.TestCase):
    """Tests for DeleteCalendarEventTool._execute with GoogleCalendarCreds mocked."""

    # Patch GoogleCalendarCreds where the tool module looks it up. mock.patch
    # cannot import a placeholder path such as "your_module.GoogleCalendarCreds".
    # Assumes delete_calendar_event imports GoogleCalendarCreds by name — confirm.
    _CREDS_TARGET = "superagi.tools.google_calendar.delete_calendar_event.GoogleCalendarCreds"

    def setUp(self):
        self.delete_tool = DeleteCalendarEventTool()

    @patch(_CREDS_TARGET)
    def test_execute_delete_event_with_valid_id(self, mock_google_calendar_creds):
        credentials_obj = Mock()
        credentials_obj.get_credentials.return_value = {"success": True, "service": Mock()}
        mock_google_calendar_creds.return_value = credentials_obj
        self.assertEqual(self.delete_tool._execute("123456"), "Event Successfully deleted from your Google Calendar")

    @patch(_CREDS_TARGET)
    def test_execute_delete_event_with_no_id(self, mock_google_calendar_creds):
        # The literal string "None" is how a missing id arrives.
        self.assertEqual(self.delete_tool._execute("None"), "Add Event ID to delete an event from Google Calendar")

    @patch(_CREDS_TARGET)
    def test_execute_delete_event_with_no_credentials(self, mock_google_calendar_creds):
        credentials_obj = Mock()
        credentials_obj.get_credentials.return_value = {"success": False}
        mock_google_calendar_creds.return_value = credentials_obj
        self.assertEqual(self.delete_tool._execute("123456"), "Kindly connect to Google Calendar")
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/tools/google_calendar/event_details_test.py
================================================
import unittest
from unittest.mock import MagicMock, patch
from pydantic import ValidationError
from superagi.tools.google_calendar.event_details_calendar import EventDetailsCalendarInput, EventDetailsCalendarTool
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
class TestEventDetailsCalendarInput(unittest.TestCase):
    """Validation tests for EventDetailsCalendarInput."""

    def test_valid_input(self):
        parsed = EventDetailsCalendarInput(event_id="test_event_id")
        self.assertEqual(parsed.event_id, "test_event_id")

    def test_invalid_input(self):
        # None is not an acceptable event id.
        with self.assertRaises(ValidationError):
            EventDetailsCalendarInput(event_id=None)
class TestEventDetailsCalendarTool(unittest.TestCase):
    """Tests for EventDetailsCalendarTool._execute with credentials mocked."""

    def setUp(self):
        self.tool = EventDetailsCalendarTool()

    def test_no_credentials(self):
        with patch.object(GoogleCalendarCreds, 'get_credentials') as mock_get_credentials:
            mock_get_credentials.return_value = {"success": False}
            result = self.tool._execute(event_id="test_event_id")
            self.assertEqual(result, "Kindly connect to Google Calendar")

    def test_no_event_id(self):
        with patch.object(GoogleCalendarCreds, 'get_credentials') as mock_get_credentials:
            mock_get_credentials.return_value = {"success": True}
            result = self.tool._execute(event_id="None")
            self.assertEqual(result, "Add Event ID to fetch details of an event from Google Calendar")

    def test_valid_event(self):
        event_data = {
            'summary': 'Test Meeting',
            'start': {'dateTime': '2022-01-01T09:00:00'},
            'end': {'dateTime': '2022-01-01T10:00:00'},
            'attendees': [{'email': 'attendee1@example.com'},
                          {'email': 'attendee2@example.com'}]
        }
        # Patch b64decode on the base64 module itself. The placeholder target
        # "your_module.base64.b64decode" cannot be imported by mock.patch.
        # Assumes the tool calls it as `base64.b64decode(...)` — confirm.
        with patch.object(GoogleCalendarCreds, 'get_credentials') as mock_get_credentials, \
                patch('base64.b64decode') as mock_b64decode:
            mock_get_credentials.return_value = {"success": True, "service": MagicMock()}
            service = mock_get_credentials.return_value["service"]
            service.events().get.return_value.execute.return_value = event_data
            mock_b64decode.return_value.decode.return_value = "decoded_event_id"
            result = self.tool._execute(event_id="test_event_id")
            mock_b64decode.assert_called_once_with("test_event_id")
            service.events().get.assert_called_once_with(calendarId="primary", eventId="decoded_event_id")
            expected_output = ("Event details for the event id 'test_event_id' is - \n"
                               "Summary : Test Meeting\n"
                               "Start Date and Time : 2022-01-01T09:00:00\n"
                               "End Date and Time : 2022-01-01T10:00:00\n"
                               "Attendees : attendee1@example.com,attendee2@example.com")
            self.assertEqual(result, expected_output)
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/tools/google_calendar/list_events_test.py
================================================
import unittest
from datetime import datetime
from unittest.mock import MagicMock, patch
from pydantic import ValidationError
from superagi.tools.google_calendar.list_calendar_events import ListCalendarEventsInput, ListCalendarEventsTool
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
from superagi.helper.calendar_date import CalendarDate
class TestListCalendarEventsInput(unittest.TestCase):
    """Validation tests for ListCalendarEventsInput."""

    def test_valid_input(self):
        try:
            ListCalendarEventsInput(
                start_time="20:00:00",
                start_date="2022-11-10",
                end_date="2022-11-11",
                end_time="22:00:00",
            )
        except ValidationError as err:
            self.fail(f"unexpected ValidationError: {err}")

    def test_invalid_input(self):
        with self.assertRaises(ValidationError):
            ListCalendarEventsInput(
                start_time="invalid time",
                start_date="invalid date",
                end_date="another invalid date",
                end_time="another invalid time",
            )
class TestListCalendarEventsTool(unittest.TestCase):
    """Tests for ListCalendarEventsTool._execute with credentials and dates mocked."""

    @patch.object(GoogleCalendarCreds, 'get_credentials')
    @patch.object(CalendarDate, 'get_date_utc')
    def test_without_events(self, mock_get_date_utc, mock_get_credentials):
        tool = ListCalendarEventsTool()
        service = MagicMock()
        mock_get_credentials.return_value = {"success": True, "service": service}
        # The Calendar API list call returns an empty response dict.
        service.events().list().execute.return_value = {}
        mock_get_date_utc.return_value = {
            'start_datetime_utc': datetime.now().isoformat(),
            'end_datetime_utc': datetime.now().isoformat(),
        }
        result = tool._execute('20:00:00', '2022-11-10', '2022-11-11', '22:00:00')
        self.assertEqual(result, "No events found for the given date and time range.")
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/unit_tests/__init__.py
================================================
================================================
FILE: tests/unit_tests/agent/__init__.py
================================================
================================================
FILE: tests/unit_tests/agent/test_agent_iteration_step_handler.py
================================================
from unittest.mock import Mock, patch, MagicMock
import pytest
from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler
from superagi.agent.agent_message_builder import AgentLlmMessageBuilder
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.agent.output_handler import ToolOutputHandler
from superagi.agent.task_queue import TaskQueue
from superagi.agent.tool_builder import ToolBuilder
from superagi.config.config import get_config
from superagi.helper.token_counter import TokenCounter
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
from superagi.resource_manager.resource_summary import ResourceSummarizer
from superagi.tools.code.write_code import CodingTool
from superagi.tools.resource.query_resource import QueryResourceTool
from superagi.tools.thinking.tools import ThinkingTool
# Given
@pytest.fixture
def test_handler():
    """Provide an AgentIterationStepHandler built on mock collaborators.

    The session and LLM are plain Mocks; the agent id and agent execution id
    are both fixed at 1.
    """
    return AgentIterationStepHandler(Mock(), Mock(), 1, 1)
def test_build_agent_prompt(test_handler, mocker):
    """_build_agent_prompt wires main and task-based variable replacement.

    Every AgentPromptBuilder, TaskQueue and TokenCounter collaborator is
    patched, so this test checks which helpers are called (and with what),
    not the rendered prompt text.
    """
    # Arrange
    iteration_workflow = IterationWorkflow(has_task_queue=True)
    agent_config = {'constraints': 'Test constraint'}
    agent_execution_config = {'goal': 'Test goal', 'instruction': 'Test instruction'}
    prompt = 'Test prompt'
    task_queue = TaskQueue(queue_name='Test queue')
    agent_tools = []
    # Both replacement steps are stubbed to return the prompt unchanged.
    mocker.patch.object(AgentPromptBuilder, 'replace_main_variables', return_value='Test prompt')
    mocker.patch.object(AgentPromptBuilder, 'replace_task_based_variables', return_value='Test prompt')
    # Stub every TaskQueue accessor the handler reads from.
    mocker.patch.object(task_queue, 'get_last_task_details', return_value={"task": "last task", "response": "last response"})
    mocker.patch.object(task_queue, 'get_first_task', return_value='Test task')
    mocker.patch.object(task_queue, 'get_tasks', return_value=[])
    mocker.patch.object(task_queue, 'get_completed_tasks', return_value=[])
    mocker.patch.object(TokenCounter, 'token_limit', return_value=1000)
    # get_config is patched in the handler module's namespace, where it is looked up.
    mocker.patch('superagi.agent.agent_iteration_step_handler.get_config', return_value=600)
    # Act
    test_handler.task_queue = task_queue
    result_prompt = test_handler._build_agent_prompt(iteration_workflow, agent_config, agent_execution_config,
                                                     prompt, agent_tools)
    # Assert
    assert result_prompt == 'Test prompt'
    # The trailing False is the last positional argument expected from the handler.
    AgentPromptBuilder.replace_main_variables.assert_called_once_with(prompt, agent_execution_config["goal"],
                                                                      agent_execution_config["instruction"],
                                                                      agent_config["constraints"], agent_tools, False)
    AgentPromptBuilder.replace_task_based_variables.assert_called_once()
    task_queue.get_last_task_details.assert_called_once()
    task_queue.get_first_task.assert_called_once()
    task_queue.get_tasks.assert_called_once()
    task_queue.get_completed_tasks.assert_called_once()
    TokenCounter.token_limit.assert_called_once()
def test_build_tools(test_handler, mocker):
    """Exercise _build_tools with patched builder/config/summary helpers.

    The session query is stubbed to return a single ThinkingTool. The test
    asserts the type of the first built tool and the call counts of the
    patched helpers.
    """
    # Arrange
    agent_config = {'model': 'gpt-3', 'tools': [1, 2, 3], 'resource_summary': True}
    agent_execution_config = {'goal': 'Test goal', 'instruction': 'Test instruction', 'tools':[1]}
    mocker.patch.object(AgentConfiguration, 'get_model_api_key', return_value={'api_key':'test_api_key','provider':'test_provider'})
    mocker.patch.object(ToolBuilder, 'build_tool')
    mocker.patch.object(ToolBuilder, 'set_default_params_tool', return_value=ThinkingTool())
    mocker.patch.object(ResourceSummarizer, 'fetch_or_create_agent_resource_summary', return_value=True)
    # NOTE(review): this patches Tool in its defining module. If the handler
    # imported Tool by name, this patch probably does not affect it. Confirm
    # whether it is needed.
    mocker.patch('superagi.models.tool.Tool')
    # session.query(...).filter(...).all() yields one ThinkingTool.
    test_handler.session.query.return_value.filter.return_value.all.return_value = [ThinkingTool()]
    # Act
    agent_tools = test_handler._build_tools(agent_config, agent_execution_config)
    # Assert
    assert isinstance(agent_tools[0], ThinkingTool)
    assert ToolBuilder.build_tool.call_count == 1
    assert ToolBuilder.set_default_params_tool.call_count == 3
    assert AgentConfiguration.get_model_api_key.call_count == 1
    assert ResourceSummarizer.fetch_or_create_agent_resource_summary.call_count == 1
def test_handle_wait_for_permission(test_handler, mocker):
    """An APPROVED permission sets the execution back to RUNNING and runs the tool.

    Class-level collaborators are patched through ``mocker`` so the patches
    are undone at teardown. Assigning a Mock directly onto the class would
    leak it into every later test in the session.
    """
    # Arrange
    mock_agent_execution = mocker.Mock(spec=AgentExecution)
    mock_agent_execution.status = "WAITING_FOR_PERMISSION"
    mock_iteration_workflow_step = mocker.Mock(spec=IterationWorkflowStep)
    mock_iteration_workflow_step.next_step_id = 123
    agent_config = {'model': 'gpt-3', 'tools': [1, 2, 3]}
    agent_execution_config = {'goal': 'Test goal', 'instruction': 'Test instruction'}
    mock_permission = mocker.Mock(spec=AgentExecutionPermission)
    mock_permission.status = "APPROVED"
    mock_permission.user_feedback = "Test feedback"
    mock_permission.tool_name = "Test tool"
    # Instance-level override: the fixture creates a fresh handler per test.
    test_handler._build_tools = Mock(return_value=[ThinkingTool()])
    # The permission row is served by session.query(...).filter(...).first().
    test_handler.session.query.return_value.filter.return_value.first.return_value = mock_permission
    mock_tool_output = mocker.MagicMock()
    mock_tool_output.result = "Test result"
    # Patched on the class but restored automatically by mocker at teardown.
    mock_handle_tool_response = mocker.patch.object(
        ToolOutputHandler, 'handle_tool_response', return_value=mock_tool_output)
    # Act
    result = test_handler._handle_wait_for_permission(
        mock_agent_execution, agent_config, agent_execution_config, mock_iteration_workflow_step)
    # Assert
    test_handler._build_tools.assert_called_once_with(agent_config, agent_execution_config)
    mock_handle_tool_response.assert_called_once()
    assert mock_agent_execution.status == "RUNNING"
    assert result
================================================
FILE: tests/unit_tests/agent/test_agent_message_builder.py
================================================
import pytest
from unittest.mock import patch, Mock
from superagi.agent.agent_message_builder import AgentLlmMessageBuilder
from superagi.models.agent_execution_feed import AgentExecutionFeed
@patch('superagi.helper.token_counter.TokenCounter.token_limit')
@patch('superagi.config.config.get_config')
def test_build_agent_messages(mock_get_config, mock_token_limit):
    """build_agent_messages starts with the system prompt and persists each message.

    With no prior feeds, every returned message is expected to have been
    added and committed to the session as an AgentExecutionFeed, in order.
    ``@patch`` injects mocks bottom-up: get_config first, then token_limit.
    """
    mock_session = Mock()
    llm = Mock()
    llm_model = Mock()
    agent_id = 1
    agent_execution_id = 1
    prompt = "start"
    agent_feeds = []
    completion_prompt = "end"
    # Mocking
    mock_token_limit.return_value = 1000
    # NOTE(review): this patches get_config in superagi.config.config. If
    # agent_message_builder imported get_config by name, the patch may not
    # reach it. Confirm the intended patch target.
    mock_get_config.return_value = 600
    builder = AgentLlmMessageBuilder(mock_session, llm, llm_model, agent_id, agent_execution_id)
    messages = builder.build_agent_messages(prompt, agent_feeds, history_enabled=True, completion_prompt=completion_prompt)
    # Test prompt message
    assert messages[0] == {"role": "system", "content": prompt}
    # Test initial feeds: one add + one commit per returned message
    assert mock_session.add.call_count == len(messages)
    assert mock_session.commit.call_count == len(messages)
    # Check if AgentExecutionFeed object is created and added to session
    for i in range(len(messages)):
        args, _ = mock_session.add.call_args_list[i]
        feed_obj = args[0]
        assert isinstance(feed_obj, AgentExecutionFeed)
        assert feed_obj.agent_execution_id == agent_execution_id
        assert feed_obj.agent_id == agent_id
        assert feed_obj.feed == messages[i]["content"]
        assert feed_obj.role == messages[i]["role"]
@patch('superagi.models.agent_execution_config.AgentExecutionConfiguration.fetch_value')
@patch('superagi.models.agent_execution_config.AgentExecutionConfiguration.add_or_update_agent_execution_config')
@patch('superagi.agent.agent_message_builder.AgentLlmMessageBuilder._build_prompt_for_recursive_ltm_summary_using_previous_ltm_summary')
@patch('superagi.agent.agent_message_builder.AgentLlmMessageBuilder._build_prompt_for_ltm_summary')
@patch('superagi.helper.token_counter.TokenCounter.count_text_tokens')
@patch('superagi.helper.token_counter.TokenCounter.token_limit')
def test_build_ltm_summary(mock_token_limit, mock_count_text_tokens, mock_build_prompt_for_ltm_summary,
                           mock_build_prompt_for_recursive_ltm_summary, mock_add_or_update_agent_execution_config,
                           mock_fetch_value):
    """_build_ltm_summary returns the LLM's summary and persists it once.

    ``@patch`` decorators inject mocks bottom-up, which is why the parameter
    order is the reverse of the decorator order.
    """
    mock_session = Mock()
    llm = Mock()
    llm_model = Mock()
    agent_id = 1
    agent_execution_id = 1
    builder = AgentLlmMessageBuilder(mock_session, llm, llm_model, agent_id, agent_execution_id)
    past_messages = [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}]
    output_token_limit = 100
    mock_token_limit.return_value = 1000
    mock_count_text_tokens.return_value = 200
    mock_build_prompt_for_ltm_summary.return_value = "ltm_summary_prompt"
    mock_build_prompt_for_recursive_ltm_summary.return_value = "recursive_ltm_summary_prompt"
    # Value returned by the patched AgentExecutionConfiguration.fetch_value.
    mock_fetch_value.return_value = Mock(value="ltm_summary")
    llm.chat_completion.return_value = {"content": "ltm_summary"}
    ltm_summary = builder._build_ltm_summary(past_messages, output_token_limit)
    assert ltm_summary == "ltm_summary"
    mock_add_or_update_agent_execution_config.assert_called_once()
    # The LLM is expected to receive the prompt from the (mocked) non-recursive builder.
    llm.chat_completion.assert_called_once_with([{"role": "system", "content": "You are GPT Prompt writer"},
                                                 {"role": "assistant", "content": "ltm_summary_prompt"}])
@patch('superagi.helper.prompt_reader.PromptReader.read_agent_prompt')
def test_build_prompt_for_ltm_summary(mock_read_agent_prompt):
    """The LTM summary prompt embeds the role-prefixed transcript and a char limit.

    With a 100-token limit the rendered prompt is expected to contain "400".
    """
    builder = AgentLlmMessageBuilder(Mock(), Mock(), Mock(), 1, 1)
    history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    mock_read_agent_prompt.return_value = "{past_messages}\n{char_limit}"

    rendered = builder._build_prompt_for_ltm_summary(history, 100)

    for fragment in ("user: Hello\nassistant: Hi\n", "400"):
        assert fragment in rendered
@patch('superagi.helper.prompt_reader.PromptReader.read_agent_prompt')
def test_build_prompt_for_recursive_ltm_summary_using_previous_ltm_summary(mock_read_agent_prompt):
    """The recursive LTM prompt embeds the previous summary, transcript and char limit."""
    builder = AgentLlmMessageBuilder(Mock(), Mock(), Mock(), 1, 1)
    history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    mock_read_agent_prompt.return_value = "{previous_ltm_summary}\n{past_messages}\n{char_limit}"

    rendered = builder._build_prompt_for_recursive_ltm_summary_using_previous_ltm_summary("Summary", history, 100)

    for fragment in ("Summary", "user: Hello\nassistant: Hi\n", "400"):
        assert fragment in rendered
================================================
FILE: tests/unit_tests/agent/test_agent_prompt_builder.py
================================================
from unittest.mock import Mock
from unittest.mock import patch
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.tools.base_tool import BaseTool
def test_add_list_items_to_string():
    """Items are rendered as a 1-based numbered list, one item per line."""
    rendered = AgentPromptBuilder.add_list_items_to_string(['item1', 'item2', 'item3'])
    expected = "".join(f"{n}. item{n}\n" for n in range(1, 4))
    assert rendered == expected
def test_clean_prompt():
    """clean_prompt removes the spaces surrounding the prompt text."""
    cleaned = AgentPromptBuilder.clean_prompt(' some text with extra spaces ')
    assert cleaned == 'some text with extra spaces'
@patch('superagi.agent.agent_prompt_builder.AgentPromptBuilder.add_list_items_to_string')
@patch('superagi.agent.agent_prompt_builder.AgentPromptBuilder.add_tools_to_prompt')
def test_replace_main_variables(mock_add_tools_to_prompt, mock_add_list_items_to_string):
    """replace_main_variables substitutes goals, instructions and constraints.

    List rendering is stubbed to a comma join, and tool rendering to a fixed
    string.
    """
    super_agi_prompt = "{goals} {instructions} {task_instructions} {constraints} {tools}"
    goals = ['goal1', 'goal2']
    instructions = ['instruction1']
    constraints = ['constraint1']
    tools = [Mock(spec=BaseTool)]
    # Mocking
    mock_add_list_items_to_string.side_effect = lambda x: ', '.join(x)
    mock_add_tools_to_prompt.return_value = 'tools_str'
    result = AgentPromptBuilder.replace_main_variables(super_agi_prompt, goals, instructions, constraints, tools)
    # NOTE(review): relies on the implementation emitting text starting with
    # "INSTRUCTION" right after the goals. Confirm against replace_main_variables.
    assert 'goal1, goal2 INSTRUCTION' in result
    assert 'instruction1' in result
    assert 'constraint1' in result
@patch('superagi.agent.agent_prompt_builder.TokenCounter.count_message_tokens')
def test_replace_task_based_variables(mock_count_message_tokens):
    """Exact-output check for replace_task_based_variables.

    NOTE(review): another function with this same name is defined later in
    this module. That later definition replaces this one at import time, so
    pytest never collects this test. The exact expected string below may be
    stale. Confirm it against the implementation before reviving this test
    under a distinct name.
    """
    super_agi_prompt = "{current_task} {last_task} {last_task_result} {pending_tasks} {completed_tasks} {task_history}"
    current_task = "task1"
    last_task = "task2"
    last_task_result = "result1"
    pending_tasks = ["task3", "task4"]
    completed_tasks = [{'task': 'task1', 'response': 'response1'}, {'task': 'task2', 'response': 'response2'}]
    token_limit = 2000
    # Mocking
    mock_count_message_tokens.return_value = 50
    result = AgentPromptBuilder.replace_task_based_variables(super_agi_prompt, current_task, last_task, last_task_result,
                                                             pending_tasks, completed_tasks, token_limit)
    # Built after the call, so it reflects completed_tasks as it is after the call.
    expected_result = f"{current_task} {last_task} {last_task_result} {str(pending_tasks)} {str([x['task'] for x in completed_tasks])} \nTask: {completed_tasks[-1]['task']}\nResult: {completed_tasks[-1]['response']}\nTask: {completed_tasks[-2]['task']}\nResult: {completed_tasks[-2]['response']}\n"
    assert result == expected_result
@patch('superagi.agent.agent_prompt_builder.TokenCounter.count_message_tokens')
def test_replace_task_based_variables(mock_count_message_tokens):
    """Every task-related input shows up somewhere in the rendered prompt.

    Uses containment checks rather than an exact string, so the ordering of
    completed tasks and task history is not pinned. This definition shares
    its name with an earlier test in this module and replaces it at import
    time.
    """
    super_agi_prompt = "{current_task} {last_task} {last_task_result} {pending_tasks} {completed_tasks} {task_history}"
    current_task = "task1"
    last_task = "task2"
    last_task_result = "result1"
    pending_tasks = ["task3", "task4"]
    completed_tasks = [{'task': 'task1', 'response': 'response1'}, {'task': 'task2', 'response': 'response2'}]
    token_limit = 2000
    # Mocking: token counting returns a constant well below token_limit.
    mock_count_message_tokens.return_value = 50
    result = AgentPromptBuilder.replace_task_based_variables(super_agi_prompt, current_task, last_task, last_task_result,
                                                             pending_tasks, completed_tasks, token_limit)
    # Both pending tasks are checked ("task3" and "task4").
    for expected in ("task1", "task2", "result1", "task3", "task4", "response1", "response2"):
        assert expected in result, f"{expected!r} missing from rendered prompt"
================================================
FILE: tests/unit_tests/agent/test_agent_prompt_template.py
================================================
import pytest
from unittest.mock import patch, mock_open
from superagi.agent.agent_prompt_template import AgentPromptTemplate
from superagi.helper.prompt_reader import PromptReader
@patch("builtins.open", new_callable=mock_open, read_data="test_prompt")
def test_get_super_agi_single_prompt(mock_file):
    """The single prompt pairs the (mocked) file text with its four variables."""
    template = AgentPromptTemplate.get_super_agi_single_prompt()
    wanted = dict(prompt="test_prompt", variables=["goals", "instructions", "constraints", "tools"])
    assert template == wanted
@patch("builtins.open", new_callable=mock_open, read_data="test_prompt")
def test_start_task_based(mock_file):
    """The task-based start prompt exposes only the goals and instructions variables."""
    template = AgentPromptTemplate.start_task_based()
    wanted = dict(prompt="test_prompt", variables=["goals", "instructions"])
    assert template == wanted
@patch("builtins.open", new_callable=mock_open, read_data="test_prompt")
def test_analyse_task(mock_file):
    """The analyse-task prompt adds tools and current_task to goals/instructions."""
    template = AgentPromptTemplate.analyse_task()
    wanted = dict(prompt="test_prompt", variables=["goals", "instructions", "tools", "current_task"])
    assert template == wanted
@patch("builtins.open", new_callable=mock_open, read_data="test_prompt")
def test_create_tasks(mock_file):
    """The create-tasks prompt exposes goal, instruction and last/pending task variables."""
    template = AgentPromptTemplate.create_tasks()
    variable_names = ["goals", "instructions", "last_task", "last_task_result", "pending_tasks"]
    assert template == dict(prompt="test_prompt", variables=variable_names)
@patch("builtins.open", new_callable=mock_open, read_data="test_prompt")
def test_prioritize_tasks(mock_file):
    """The prioritize-tasks prompt uses the same variables as create-tasks."""
    template = AgentPromptTemplate.prioritize_tasks()
    variable_names = ["goals", "instructions", "last_task", "last_task_result", "pending_tasks"]
    assert template == dict(prompt="test_prompt", variables=variable_names)
================================================
FILE: tests/unit_tests/agent/test_agent_tool_step_handler.py
================================================
import json
from unittest.mock import Mock, create_autospec, patch
import pytest
from superagi.agent.agent_tool_step_handler import AgentToolStepHandler
from superagi.agent.common_types import ToolExecutorResponse
from superagi.agent.output_handler import ToolOutputHandler
from superagi.agent.tool_builder import ToolBuilder
from superagi.helper.token_counter import TokenCounter
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.models.tool import Tool
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.agent_workflow_step_tool import AgentWorkflowStepTool
from superagi.resource_manager.resource_summary import ResourceSummarizer
from superagi.tools.code.write_code import CodingTool
# Given
@pytest.fixture
def handler():
    """Provide an AgentToolStepHandler built on mock session and LLM.

    Agent id and agent execution id are both 1; the final constructor
    argument is None.
    """
    session, llm = Mock(), Mock()
    return AgentToolStepHandler(session, llm, 1, 1, None)
def test_create_permission_request(handler):
    """A PENDING permission row is created and the execution is set to wait on it."""
    # Arrange
    execution = Mock()
    step_tool = Mock()
    step_tool.input_instruction = "input_instruction"
    handler.session.commit = Mock()
    handler.session.flush = Mock()
    mock_permission = create_autospec(AgentExecutionPermission)
    with patch('superagi.agent.agent_tool_step_handler.AgentExecutionPermission', return_value=mock_permission) as mock_cls:
        # Act
        handler._create_permission_request(execution, step_tool)
    # Assert
    mock_cls.assert_called_once_with(
        agent_execution_id=handler.agent_execution_id,
        status="PENDING",
        agent_id=handler.agent_id,
        tool_name="WAIT_FOR_PERMISSION",
        question=step_tool.input_instruction,
        assistant_reply=""
    )
    handler.session.add.assert_called_once_with(mock_permission)
    # These must be comparisons: plain assignments here would overwrite what
    # the handler set and verify nothing.
    assert execution.permission_id == mock_permission.id
    assert execution.status == "WAITING_FOR_PERMISSION"
    assert handler.session.commit.call_count == 2
    assert handler.session.flush.call_count == 1
def test_execute_step(handler):
    """execute_step orchestrates permission, input-instruction and output-instruction stages.

    Every per-stage helper on the handler is replaced with a Mock, and the
    model lookups (execution, step, step tool, configs) are patched. Only the
    orchestration and the arguments passed between stages are exercised.
    """
    # Arrange
    execution = create_autospec(AgentExecution)
    workflow_step = create_autospec(AgentWorkflowStep)
    step_tool = create_autospec(AgentWorkflowStepTool)
    agent_config = {}
    agent_execution_config = {}
    with patch.object(AgentExecution, 'get_agent_execution_from_id', return_value=execution), \
            patch.object(AgentWorkflowStep, 'find_by_id', return_value=workflow_step), \
            patch.object(AgentWorkflowStepTool, 'find_by_id', return_value=step_tool), \
            patch.object(Agent, 'fetch_configuration', return_value=agent_config), \
            patch.object(AgentExecutionConfiguration, 'fetch_configuration', return_value=agent_execution_config):
        handler._handle_wait_for_permission = Mock(return_value=True)
        handler._create_permission_request = Mock()
        # NOTE(review): returns the literal string '{"}' (not valid JSON).
        # Presumably deliberate; confirm what execute_step should do with it.
        handler._process_input_instruction = Mock(return_value="{\"}")
        handler._build_tool_obj = Mock()
        handler._process_output_instruction = Mock(return_value="step_response")
        handler._handle_next_step = Mock()
        # Act
        # ToolOutputHandler is swapped for a spec'd Mock that reports SUCCESS.
        tool_output_handler = Mock(spec=ToolOutputHandler)
        tool_output_handler.handle.return_value = ToolExecutorResponse(status="SUCCESS", output="final_response")
        with patch('superagi.agent.agent_tool_step_handler.ToolOutputHandler', return_value=tool_output_handler):
            # Act
            handler.execute_step()
    # Assert
    handler._handle_wait_for_permission.assert_called_once()
    handler._process_input_instruction.assert_called_once_with(agent_config, agent_execution_config, step_tool,
                                                               workflow_step)
    handler._process_output_instruction.assert_called_once()
def test_handle_next_step_with_complete(handler):
    """A "COMPLETE" next step marks the execution COMPLETED with step id -1."""
    finished_execution = create_autospec(AgentExecution)
    with patch.object(AgentExecution, 'get_agent_execution_from_id', return_value=finished_execution):
        handler._handle_next_step("COMPLETE")
    observed = (finished_execution.current_agent_step_id, finished_execution.status)
    assert observed == (-1, "COMPLETED")
    handler.session.commit.assert_called_once()
def test_handle_next_step_with_next_step(handler):
    """A concrete next step is passed to AgentExecution.assign_next_step_id, then committed."""
    upcoming_step = create_autospec(AgentExecution)  # stand-in for the next step object
    current_execution = create_autospec(AgentExecution)
    lookup_patch = patch.object(AgentExecution, 'get_agent_execution_from_id', return_value=current_execution)
    assign_patch = patch.object(AgentExecution, 'assign_next_step_id')
    with lookup_patch, assign_patch as assign_mock:
        handler._handle_next_step(upcoming_step)
    assign_mock.assert_called_once_with(handler.session, handler.agent_execution_id, upcoming_step.id)
    handler.session.commit.assert_called_once()
def test_build_tool_obj(handler):
    """_build_tool_obj returns the tool produced by the patched ToolBuilder pipeline."""
    built_tool = Tool()
    config_for_agent = {"model": "model1", "resource_summary": "summary"}
    api_key_info = {"provider": "provider", "api_key": "apikey"}
    summary_text = "summary"
    query_stub = Mock(first=Mock(return_value=built_tool))
    with patch.object(AgentConfiguration, 'get_model_api_key', return_value=api_key_info), \
            patch.object(ToolBuilder, 'build_tool', return_value=built_tool), \
            patch.object(ToolBuilder, 'set_default_params_tool', return_value=built_tool), \
            patch.object(ResourceSummarizer, 'fetch_or_create_agent_resource_summary', return_value=summary_text), \
            patch.object(handler.session, 'query', return_value=query_stub):
        outcome = handler._build_tool_obj(config_for_agent, {}, "QueryResourceTool")
    assert outcome == built_tool
def test_process_output_instruction(handler):
    """_process_output_instruction returns the LLM reply's 'content'.

    Prompt building, token counting, the token limit, the LLM call and
    AgentExecution.update_tokens are all patched.
    """
    # Arrange
    final_response = "final_response"
    step_tool = AgentWorkflowStepTool()
    workflow_step = AgentWorkflowStep()
    mock_response = {"content": "response_content"}
    current_tokens = 10
    token_limit = 100
    with patch.object(handler, '_build_tool_output_prompt', return_value="prompt"), \
            patch.object(TokenCounter, 'count_message_tokens', return_value=current_tokens), \
            patch.object(TokenCounter, 'token_limit', return_value=token_limit), \
            patch.object(handler.llm, 'chat_completion', return_value=mock_response), \
            patch.object(AgentExecution, 'update_tokens'):
        # Act
        result = handler._process_output_instruction(final_response, step_tool, workflow_step)
    # Assert
    assert result == mock_response['content']
def test_build_tool_input_prompt(handler):
    """_build_tool_input_prompt renders goals, tool name and instruction into the template.

    NOTE(review): the replace() calls on ``result`` below only change
    something if the handler left a placeholder unsubstituted. In that case
    they would hide the failure. Consider asserting on ``result`` directly.
    """
    # Arrange
    step_tool = AgentWorkflowStepTool()
    step_tool.tool_name = "CodingTool"
    step_tool.input_instruction = "TestInstruction"
    tool = CodingTool()
    # tool.name = "TestTool"
    # tool.description = "TestDescription"
    # tool.args = {"arg1": "val1"}
    agent_execution_config = {"goal": ["Goal1", "Goal2"]}
    mock_prompt = "{goals}{tool_name}{instruction}{tool_schema}"
    with patch('superagi.agent.agent_tool_step_handler.PromptReader.read_agent_prompt', return_value=mock_prompt), \
            patch('superagi.agent.agent_tool_step_handler.AgentPromptBuilder.add_list_items_to_string', return_value="Goal1, Goal2"):
        # Act
        result = handler._build_tool_input_prompt(step_tool, tool, agent_execution_config)
    # Assert
    result = result.replace("{goals}", "Goal1, Goal2")
    result = result.replace("{tool_name}", step_tool.tool_name)
    result = result.replace("{instruction}", step_tool.input_instruction)
    tool_schema = f"\"{tool.name}\": {tool.description}, args json schema: {json.dumps(tool.args)}"
    result = result.replace("{tool_schema}", tool_schema)
    assert """Goal1, Goal2CodingToolTestInstruction""" in result
def test_build_tool_output_prompt(handler):
    """The output prompt combines tool output, tool name, instruction and options.

    The expected prompt lists only 'option1' and 'option2', even though
    'default' is also among the stubbed step responses.
    """
    step_tool = AgentWorkflowStepTool()
    step_tool.tool_name = "TestTool"
    step_tool.output_instruction = "TestInstruction"
    workflow_step = AgentWorkflowStep()
    template = "{tool_output}{tool_name}{instruction}{output_options}"
    responses = ["option1", "option2", "default"]
    with patch('superagi.agent.agent_tool_step_handler.PromptReader.read_agent_prompt', return_value=template), \
            patch.object(handler, '_get_step_responses', return_value=responses):
        rendered = handler._build_tool_output_prompt(step_tool, "TestOutput", workflow_step)
    assert rendered == "TestOutputTestToolTestInstruction['option1', 'option2']"
def test_handle_wait_for_permission_approved(handler):
    """An APPROVED permission resumes the execution and advances to the next step.

    Expects a False return, status RUNNING and permission_id reset to -1.
    """
    # Arrange
    agent_execution = AgentExecution()
    agent_execution.status = "WAITING_FOR_PERMISSION"
    agent_execution.permission_id = 123
    workflow_step = AgentWorkflowStep()
    agent_execution_permission = AgentExecutionPermission()
    agent_execution_permission.status = "APPROVED"
    next_step = AgentWorkflowStep()
    handler.session.query.return_value.filter.return_value.first.return_value = agent_execution_permission
    handler._handle_next_step = Mock()
    # patch.object restores AgentWorkflowStep.fetch_next_step on exit. Assigning
    # a Mock onto the class would leak it into every later test.
    with patch.object(AgentWorkflowStep, 'fetch_next_step', return_value=next_step):
        # Act
        result = handler._handle_wait_for_permission(agent_execution, workflow_step)
    # Assert
    assert result == False
    handler._handle_next_step.assert_called_once_with(next_step)
    assert agent_execution.status == "RUNNING"
    assert agent_execution.permission_id == -1
def test_handle_wait_for_permission_denied(handler):
    """A DENIED permission (with feedback) also resumes and advances to the next step.

    Expects a False return, status RUNNING and permission_id reset to -1.
    """
    # Arrange
    agent_execution = AgentExecution()
    agent_execution.status = "WAITING_FOR_PERMISSION"
    agent_execution.permission_id = 123
    workflow_step = AgentWorkflowStep()
    agent_execution_permission = AgentExecutionPermission()
    agent_execution_permission.status = "DENIED"
    agent_execution_permission.user_feedback = "User feedback"
    next_step = AgentWorkflowStep()
    handler.session.query.return_value.filter.return_value.first.return_value = agent_execution_permission
    handler._handle_next_step = Mock()
    # patch.object restores AgentWorkflowStep.fetch_next_step on exit. Assigning
    # a Mock onto the class would leak it into every later test.
    with patch.object(AgentWorkflowStep, 'fetch_next_step', return_value=next_step):
        # Act
        result = handler._handle_wait_for_permission(agent_execution, workflow_step)
    # Assert
    assert result == False
    handler._handle_next_step.assert_called_once_with(next_step)
    assert agent_execution.status == "RUNNING"
    assert agent_execution.permission_id == -1
================================================
FILE: tests/unit_tests/agent/test_agent_workflow_step_wait_handler.py
================================================
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from superagi.models.agent_execution import AgentExecution
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.agent.agent_workflow_step_wait_handler import AgentWaitStepHandler
# Mock datetime.now() for testing
@pytest.fixture
def mock_datetime_now():
    """Fixed timestamp (2023-09-06 12:00:00) used as the frozen 'now'."""
    frozen_moment = datetime(2023, 9, 6, 12, 0, 0)
    return frozen_moment
@pytest.fixture(autouse=True)
def mock_datetime_now_fixture(monkeypatch, mock_datetime_now):
    """Replace the wait handler module's ``datetime`` with a MagicMock.

    The replacement's ``now()`` returns the frozen timestamp. Any other
    ``datetime`` attribute used by that module also becomes a MagicMock.
    """
    frozen_now = MagicMock(return_value=mock_datetime_now)
    fake_datetime = MagicMock(now=frozen_now)
    monkeypatch.setattr("superagi.agent.agent_workflow_step_wait_handler.datetime", fake_datetime)
# Test cases
@patch.object(AgentExecution, 'get_agent_execution_from_id')
@patch.object(AgentWorkflowStep, 'find_by_id')
@patch.object(AgentWorkflowStep, 'fetch_next_step')
def test_handle_next_step_complete(mock_fetch_next_step, mock_find_by_id, mock_get_agent_execution_from_id, mock_datetime_now_fixture):
    """A next step whose str() is "COMPLETE" finishes the execution."""
    session = MagicMock()
    execution = MagicMock(current_agent_step_id=1, status="WAIT_STEP")
    terminal_step = MagicMock(id=2)
    terminal_step.__str__.return_value = "COMPLETE"

    mock_get_agent_execution_from_id.return_value = execution
    mock_find_by_id.return_value = MagicMock()
    mock_fetch_next_step.return_value = terminal_step

    AgentWaitStepHandler(session, 1, 2).handle_next_step()

    assert execution.current_agent_step_id == -1
    assert execution.status == "COMPLETED"
    session.commit.assert_called_once()
# Test cases
@patch.object(AgentExecution, 'get_agent_execution_from_id')
@patch.object(AgentWorkflowStep, 'find_by_id')
@patch.object(AgentWorkflowStep, 'fetch_next_step')
def test_execute_step(mock_fetch_next_step, mock_find_by_id, mock_get_agent_execution_from_id):
    """execute_step leaves both the wait step and the execution in their waiting states."""
    session = MagicMock()
    execution = MagicMock(current_agent_step_id=1, status="WAIT_STEP")
    wait_step = MagicMock(status="WAITING")

    mock_get_agent_execution_from_id.return_value = execution
    mock_find_by_id.return_value = wait_step
    mock_fetch_next_step.return_value = MagicMock()

    AgentWaitStepHandler(session, 1, 2).execute_step()

    assert wait_step.status == "WAITING"
    assert execution.status == "WAIT_STEP"
    session.commit.assert_called_once()
================================================
FILE: tests/unit_tests/agent/test_output_handler.py
================================================
import pytest
from unittest.mock import Mock, patch, MagicMock
from superagi.agent.common_types import ToolExecutorResponse
from superagi.agent.output_handler import ToolOutputHandler, TaskOutputHandler, ReplaceTaskOutputHandler
from superagi.agent.output_parser import AgentSchemaOutputParser, AgentGPTAction
from superagi.agent.task_queue import TaskQueue
from superagi.agent.tool_executor import ToolExecutor
from superagi.helper.json_cleaner import JsonCleaner
from superagi.models.agent import Agent
from superagi.models.agent_execution_permission import AgentExecutionPermission
import numpy as np
from superagi.agent.output_handler import ToolOutputHandler
# Test for ToolOutputHandler
@patch.object(TaskQueue, 'complete_task')
@patch.object(TaskQueue, 'get_tasks')
@patch.object(TaskQueue, 'get_completed_tasks')
@patch.object(AgentSchemaOutputParser, 'parse')
def test_tool_output_handle(parse_mock, get_completed_tasks_mock, get_tasks_mock, complete_task_mock):
    """ToolOutputHandler.handle parses the reply and adds two rows to the session.

    ``@patch`` decorators inject mocks bottom-up, so the parameters are, in
    order: parse, get_completed_tasks, get_tasks, complete_task.
    """
    # Arrange
    agent_execution_id = 11
    agent_config = {"agent_id": 22, "permission_type": "unrestricted"}
    assistant_reply = '{"tool": {"name": "someAction", "args": ["arg1", "arg2"]}}'
    parse_mock.return_value = AgentGPTAction(name="someAction", args=["arg1", "arg2"])
    # NOTE(review): this configures TaskQueue.get_completed_tasks, not a tool
    # executor. The tool response itself is stubbed via handle_tool_response
    # below, so this line is likely unnecessary. Confirm and drop.
    get_completed_tasks_mock.return_value = Mock(status='PENDING', is_permission_required=False)
    handler = ToolOutputHandler(agent_execution_id, agent_config, [], None)
    # Mock session
    session_mock = MagicMock()
    session_mock.query.return_value.filter.return_value.first.return_value = Mock()
    # Instance-level stubs so the tool itself never runs.
    handler._check_for_completion = Mock(return_value=Mock(status='PENDING', is_permission_required=False))
    handler.handle_tool_response = Mock(return_value=Mock(status='PENDING', is_permission_required=False))
    # Act
    response = handler.handle(session_mock, assistant_reply)
    # Assert
    assert response.status == "PENDING"
    parse_mock.assert_called_with(assistant_reply)
    assert session_mock.add.call_count == 2
@patch('superagi.agent.output_handler.TokenTextSplitter')
def test_add_text_to_memory(TokenTextSplitter_mock):
    """Reply thoughts plus the tool result are chunked and stored with execution metadata."""
    execution_id = 1
    output_handler = ToolOutputHandler(execution_id, {"agent_id": 2}, [], None)

    splitter = MagicMock()
    splitter.split_text.return_value = ["This is a task.", "Task completed."]
    TokenTextSplitter_mock.return_value = splitter

    memory = MagicMock()
    output_handler.memory = memory

    output_handler.add_text_to_memory('{"thoughts": {"text": "This is a task."}}', '["Task completed."]')

    TokenTextSplitter_mock.assert_called_once_with(chunk_size=1024, chunk_overlap=10)
    splitter.split_text.assert_called_once_with('This is a task.["Task completed."]')
    per_chunk_metadata = [{"agent_execution_id": execution_id}] * 2
    memory.add_texts.assert_called_once_with(["This is a task.", "Task completed."], per_chunk_metadata)
@patch('superagi.models.agent_execution_permission.AgentExecutionPermission')
def test_tool_handler_check_permission_in_restricted_mode(op_mock):
    """In RESTRICTED mode, a permission-required tool yields WAITING_FOR_PERMISSION.

    NOTE(review): ``op_mock`` is the patched AgentExecutionPermission class,
    not an output parser, so ``op_mock.parse`` below configures nothing the
    handler uses. The assistant reply is presumably parsed for real. Patching
    the class in its defining module may also not reach output_handler if that
    module imported the name directly. Confirm the patch target.
    """
    # Mock the session
    session_mock = MagicMock()
    # Arrange
    agent_execution_id = 1
    agent_config = {"agent_id": 2, "permission_type": "RESTRICTED"}
    assistant_reply = '{"tool": {"name": "someAction", "args": ["arg1", "arg2"]}}'
    op_mock.parse.return_value = AgentGPTAction(name="someAction", args=["arg1", "arg2"])
    # The only registered tool matches the reply's tool name and requires permission.
    tool = MagicMock()
    tool.name = "someAction"
    tool.permission_required = True
    handler = ToolOutputHandler(agent_execution_id, agent_config, [tool],None)
    # Act
    response = handler._check_permission_in_restricted_mode(session_mock, assistant_reply)
    # Assert
    assert response.is_permission_required
    assert response.status == "WAITING_FOR_PERMISSION"
    session_mock.add.assert_called_once()
    session_mock.commit.assert_called_once()
# Test for TaskOutputHandler
@patch.object(TaskQueue, 'add_task')
@patch.object(TaskQueue, 'get_tasks')
@patch.object(JsonCleaner, 'extract_json_array_section')
def test_task_output_handle_method(extract_json_array_section_mock, get_tasks_mock, add_task_mock):
    """TaskOutputHandler.handle enqueues every parsed task and reports PENDING."""
    reply = '["task1", "task2", "task3"]'
    parsed_tasks = ["task1", "task2", "task3"]
    extract_json_array_section_mock.return_value = str(parsed_tasks)
    get_tasks_mock.return_value = parsed_tasks

    handler = TaskOutputHandler(1, {"agent_id": 2})
    db_session = MagicMock()
    result = handler.handle(db_session, reply)

    extract_json_array_section_mock.assert_called_once_with(reply)
    task_total = len(parsed_tasks)
    assert add_task_mock.call_count == task_total
    assert db_session.add.call_count == task_total
    get_tasks_mock.assert_called_once()
    assert result.status == "PENDING"
# Test for ReplaceTaskOutputHandler
@patch.object(TaskQueue, 'clear_tasks')
@patch.object(TaskQueue, 'add_task')
@patch.object(TaskQueue, 'get_tasks')
@patch.object(JsonCleaner, 'extract_json_array_section')
def test_handle_method(extract_json_array_section_mock, get_tasks_mock, add_task_mock, clear_tasks_mock):
    """ReplaceTaskOutputHandler.handle clears the queue and enqueues every parsed task."""
    reply = '["task1", "task2", "task3"]'
    new_tasks = ["task1", "task2", "task3"]
    extract_json_array_section_mock.return_value = str(new_tasks)
    get_tasks_mock.return_value = new_tasks

    handler = ReplaceTaskOutputHandler(1, {})
    outcome = handler.handle(MagicMock(), reply)

    extract_json_array_section_mock.assert_called_once_with(reply)
    clear_tasks_mock.assert_called_once()
    assert add_task_mock.call_count == len(new_tasks)
    get_tasks_mock.assert_called_once()
    assert outcome.status == "PENDING"
================================================
FILE: tests/unit_tests/agent/test_output_parser.py
================================================
import pytest
from superagi.agent.output_parser import AgentGPTAction, AgentSchemaOutputParser
import pytest
def test_agent_schema_output_parser():
    """AgentSchemaOutputParser extracts the tool name/args and rejects non-JSON input."""
    parser = AgentSchemaOutputParser()

    # Strict JSON wrapped in a code fence.
    action = parser.parse('```{"tool": {"name": "Tool1", "args": {}}}```')
    assert isinstance(action, AgentGPTAction)
    assert action.name == 'Tool1'
    assert action.args == {}

    # Python-literal style payload (single quotes, bare True).
    action = parser.parse("```{'tool': {'name': 'Tool1', 'args': 'arg1'}, 'status': True}```")
    assert isinstance(action, AgentGPTAction)
    assert action.name == 'Tool1'
    assert action.args == 'arg1'

    # Both an unparseable and an empty response must raise.
    for bad_response in ("invalid response", ""):
        with pytest.raises(Exception):
            parser.parse(bad_response)
================================================
FILE: tests/unit_tests/agent/test_queue_step_handler.py
================================================
import pytest
from unittest.mock import Mock, patch
from superagi.agent.queue_step_handler import QueueStepHandler
# To avoid rebuilding the handler in every test, set up a shared pytest fixture
@pytest.fixture
def queue_step_handler():
    """QueueStepHandler built on a mock session and mock LLM, agent id 1 and execution id 1."""
    mock_db_session, mock_llm = Mock(), Mock()
    return QueueStepHandler(mock_db_session, mock_llm, 1, 1)
@pytest.fixture
def step_tool():
    """Mock workflow step tool with fixed unique_id and input_instruction strings."""
    return Mock(unique_id="unique_id", input_instruction="input_instruction")
def test_queue_identifier(queue_step_handler):
    """The queue id is the step tool's unique_id suffixed with "_1" for this fixture."""
    tool = Mock(unique_id="step_id")
    assert queue_step_handler._queue_identifier(tool) == "step_id_1"
@patch("superagi.agent.queue_step_handler.AgentExecution")
@patch("superagi.agent.queue_step_handler.AgentWorkflowStep")
@patch("superagi.agent.queue_step_handler.AgentWorkflowStepTool")
@patch("superagi.agent.queue_step_handler.TaskQueue")
def test_execute_step(task_queue_mock, step_tool_mock, workflow_step_mock, agent_execution_mock, queue_step_handler):
    """execute_step returns "default" when the task queue reports no status.

    Stacked @patch decorators inject their mocks bottom-up: the decorator nearest
    the function supplies the first argument. The parameter order above follows
    that rule so each stub below lands on the class it is meant for.
    """
    agent_execution_mock.get_agent_execution_from_id.return_value = Mock(current_agent_step_id="step_id")
    workflow_step_mock.find_by_id.return_value = Mock(action_reference_id="action_id")
    # A real string unique_id lets the handler build the queue identifier from it.
    step_tool_mock.find_by_id.return_value = Mock(unique_id="unique_id")
    # TaskQueue(...) instances report no status.
    task_queue_mock.return_value.get_status.return_value = None

    assert queue_step_handler.execute_step() == "default"
@patch("superagi.agent.queue_step_handler.TaskQueue")
@patch("superagi.agent.queue_step_handler.AgentExecutionFeed")
def test_add_to_queue(agent_execution_feed_mock, task_queue_mock, queue_step_handler, step_tool):
    """_add_to_queue processes the step instruction and forwards the reply to _process_reply.

    @patch mocks are injected bottom-up, so the AgentExecutionFeed mock (innermost
    decorator) arrives first and the TaskQueue mock second.
    """
    reply = '{"reply": ["task1", "task2"]}'
    queue_step_handler._process_input_instruction = Mock(return_value=reply)
    queue_step_handler._process_reply = Mock()

    queue_step_handler._add_to_queue(task_queue_mock, step_tool)

    queue_step_handler._process_input_instruction.assert_called_once_with(step_tool)
    queue_step_handler._process_reply.assert_called_once_with(task_queue_mock, reply)
@patch("superagi.agent.queue_step_handler.TaskQueue")
@patch("superagi.agent.queue_step_handler.AgentExecutionFeed")
def test_consume_from_queue(agent_execution_feed_mock, task_queue_mock, queue_step_handler, step_tool):
    """_consume_from_queue persists via the session and completes the task as "PROCESSED".

    @patch mocks are injected bottom-up, so the AgentExecutionFeed mock (innermost
    decorator) arrives first and the TaskQueue mock second.
    """
    task_queue_mock.get_tasks.return_value = ['task1', 'task2']
    task_queue_mock.get_first_task.return_value = 'task1'

    queue_step_handler._consume_from_queue(task_queue_mock)

    queue_step_handler.session.commit.assert_called()
    queue_step_handler.session.add.assert_called()
    task_queue_mock.complete_task.assert_called_once_with("PROCESSED")
================================================
FILE: tests/unit_tests/agent/test_task_queue.py
================================================
import unittest
from unittest.mock import patch
from superagi.agent.task_queue import TaskQueue
class TaskQueueTests(unittest.TestCase):
    """Smoke tests for TaskQueue's public methods.

    Each test replaces the method under test with a mock via patch.object and checks
    that the call went through. patch.object fails if the attribute is missing, so
    these tests confirm the methods exist on TaskQueue. They do not exercise the
    queue's real behaviour.
    """

    def setUp(self):
        self.queue_name = "test_queue"
        self.queue = TaskQueue(self.queue_name)

    @patch.object(TaskQueue, 'add_task')
    def test_add_task(self, mock_add_task):
        task = "Do something"
        self.queue.add_task(task)
        mock_add_task.assert_called_with(task)

    @patch.object(TaskQueue, 'complete_task')
    def test_complete_task(self, mock_complete_task):
        response = "Task completed"
        self.queue.complete_task(response)
        mock_complete_task.assert_called_with(response)

    @patch.object(TaskQueue, 'get_first_task')
    def test_get_first_task(self, mock_get_first_task):
        self.queue.get_first_task()
        mock_get_first_task.assert_called()

    @patch.object(TaskQueue, 'get_tasks')
    def test_get_tasks(self, mock_get_tasks):
        self.queue.get_tasks()
        mock_get_tasks.assert_called()

    @patch.object(TaskQueue, 'get_completed_tasks')
    def test_get_completed_tasks(self, mock_get_completed_tasks):
        self.queue.get_completed_tasks()
        mock_get_completed_tasks.assert_called()

    @patch.object(TaskQueue, 'clear_tasks')
    def test_clear_tasks(self, mock_clear_tasks):
        self.queue.clear_tasks()
        mock_clear_tasks.assert_called()

    @patch.object(TaskQueue, 'get_last_task_details')
    def test_get_last_task_details(self, mock_get_last_task_details):
        self.queue.get_last_task_details()
        mock_get_last_task_details.assert_called()
# Allow running this module directly as a script via unittest's runner.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/unit_tests/agent/test_tool_builder.py
================================================
import pytest
from unittest.mock import Mock, patch
from superagi.agent.tool_builder import ToolBuilder
from superagi.models.tool import Tool
@pytest.fixture
def session():
    """Mock DB session handed to ToolBuilder."""
    fake_session = Mock()
    return fake_session
@pytest.fixture
def agent_id():
    """Agent id used when constructing the ToolBuilder."""
    builder_agent_id = 1
    return builder_agent_id
@pytest.fixture
def tool_builder(session, agent_id):
    """ToolBuilder bound to the mock session and agent id fixtures."""
    builder = ToolBuilder(session, agent_id)
    return builder
@pytest.fixture
def tool():
    """Tool spec-mock pointing at test_folder/test.py and class TestClass."""
    fake_tool = Mock(spec=Tool)
    for attribute, value in (
        ('file_name', 'test.py'),
        ('folder_name', 'test_folder'),
        ('class_name', 'TestClass'),
    ):
        setattr(fake_tool, attribute, value)
    return fake_tool
@pytest.fixture
def agent_config():
    """Minimal agent configuration naming the model."""
    return dict(model="gpt4")
@pytest.fixture
def agent_execution_config():
    """Execution-level config carrying a goal and an instruction."""
    return dict(goal="Test Goal", instruction="Test Instruction")
@patch('superagi.agent.tool_builder.importlib.import_module')
@patch('superagi.agent.tool_builder.getattr')
def test_build_tool(mock_getattr, mock_import_module, tool_builder, tool):
    """build_tool imports the tool module, resolves its class and attaches toolkit config."""
    fake_module, fake_class = Mock(), Mock()
    mock_import_module.return_value = fake_module
    mock_getattr.return_value = fake_class

    built = tool_builder.build_tool(tool)

    # The module path is derived from folder_name and file_name (".py" dropped).
    mock_import_module.assert_called_with('.test_folder.test')
    mock_getattr.assert_called_with(fake_module, tool.class_name)
    config = built.toolkit_config
    assert config.session == tool_builder.session
    assert config.toolkit_id == tool.toolkit_id
================================================
FILE: tests/unit_tests/agent/test_tool_executor.py
================================================
import pytest
from unittest.mock import Mock, patch
from pydantic import ValidationError
from superagi.agent.common_types import ToolExecutorResponse
from superagi.agent.tool_executor import ToolExecutor
class MockTool:
    """Minimal tool stub whose execute() returns the tool's own name."""

    def __init__(self, name):
        self.name = name

    def execute(self, args):
        # The arguments are deliberately ignored; echoing the name keeps results easy to assert.
        result = self.name
        return result
@pytest.fixture
def mock_tools():
    """Five MockTool instances named tool0 .. tool4."""
    tools = []
    for index in range(5):
        tools.append(MockTool(name=f'tool{index}'))
    return tools
@pytest.fixture
def executor(mock_tools):
    """ToolExecutor over the mock tools; organisation, agent and execution ids are all 1."""
    ids = dict(organisation_id=1, agent_id=1, agent_execution_id=1)
    return ToolExecutor(tools=mock_tools, **ids)
def test_tool_executor_finish(executor):
    """Invoking the 'finish' tool returns a COMPLETE response with an empty result."""
    outcome = executor.execute(None, 'finish', {})
    assert (outcome.status, outcome.result) == ('COMPLETE', '')
@patch('superagi.agent.tool_executor.EventHandler')
def test_tool_executor_success(mock_event_handler, executor, mock_tools):
    """Each registered tool runs and its output is wrapped in a SUCCESS, no-retry response."""
    for i, tool in enumerate(mock_tools):
        res = executor.execute(None, f'tool{i}', {'agent_execution_id': 1})
        assert res.status == 'SUCCESS'
        assert res.result == f'Tool {tool.name} returned: {tool.name}'
        # Identity check rather than `== False` (PEP 8 / flake8 E712).
        assert res.retry is False
@patch('superagi.agent.tool_executor.EventHandler')
def test_tool_executor_generic_error(mock_event_handler, executor):
    """A tool that raises yields an ERROR response that requests a retry."""
    tool = MockTool('error_tool')
    tool.execute = Mock(side_effect=Exception('generic error'))
    executor.tools.append(tool)
    res = executor.execute(None, 'error_tool', {})
    assert res.status == 'ERROR'
    assert 'Error1: generic error' in res.result
    # Identity check rather than `== True` (PEP 8 / flake8 E712).
    assert res.retry is True
def test_tool_executor_unknown_tool(executor):
    """Naming a tool that is not registered yields an ERROR response that requests a retry."""
    res = executor.execute(None, 'unknown_tool', {})
    assert res.status == 'ERROR'
    assert "Unknown tool 'unknown_tool'" in res.result
    # Identity check rather than `== True` (PEP 8 / flake8 E712).
    assert res.retry is True
def test_clean_tool_args(executor):
    """clean_tool_args unwraps {"value": x} entries and keeps plain values as-is."""
    raw_args = {"arg1": {"value": 1}, "arg2": 2}
    assert executor.clean_tool_args(raw_args) == {"arg1": 1, "arg2": 2}
================================================
FILE: tests/unit_tests/apm/__init__.py
================================================
================================================
FILE: tests/unit_tests/apm/test_analytics_helper.py
================================================
import pytest
from superagi.models.events import Event
from superagi.apm.analytics_helper import AnalyticsHelper
from unittest.mock import MagicMock
@pytest.fixture
def organisation_id():
    """Organisation id passed to AnalyticsHelper."""
    org_id = 1
    return org_id
@pytest.fixture
def mock_session():
    """MagicMock standing in for the SQLAlchemy session."""
    fake_session = MagicMock()
    return fake_session
@pytest.fixture
def analytics_helper(mock_session, organisation_id):
    """AnalyticsHelper bound to the mock session and organisation id."""
    helper = AnalyticsHelper(mock_session, organisation_id)
    return helper
def test_calculate_run_completed_metrics(analytics_helper, mock_session):
    """calculate_run_completed_metrics returns a dict when the query yields a row."""
    rows = [MagicMock()]
    mock_session.query().all.return_value = rows
    metrics = analytics_helper.calculate_run_completed_metrics()
    assert isinstance(metrics, dict)
def test_fetch_agent_data(analytics_helper, mock_session):
    """fetch_agent_data returns a dict when the query yields a row."""
    rows = [MagicMock()]
    mock_session.query().all.return_value = rows
    agent_data = analytics_helper.fetch_agent_data()
    assert isinstance(agent_data, dict)
def test_fetch_agent_runs(analytics_helper, mock_session):
    """fetch_agent_runs returns a list when the query yields a row."""
    rows = [MagicMock()]
    mock_session.query().all.return_value = rows
    runs = analytics_helper.fetch_agent_runs(1)  # presumably an agent id — confirm
    assert isinstance(runs, list)
def test_get_active_runs(analytics_helper, mock_session):
    """get_active_runs returns a list when the query yields a row."""
    rows = [MagicMock()]
    mock_session.query().all.return_value = rows
    active = analytics_helper.get_active_runs()
    assert isinstance(active, list)
================================================
FILE: tests/unit_tests/apm/test_call_log_helper.py
================================================
import pytest
from sqlalchemy.exc import SQLAlchemyError
from superagi.models.call_logs import CallLogs
from superagi.models.agent import Agent
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit
from unittest.mock import MagicMock
from superagi.apm.call_log_helper import CallLogHelper
@pytest.fixture
def mock_session():
    """MagicMock standing in for the SQLAlchemy session used by CallLogHelper."""
    fake_session = MagicMock()
    return fake_session
@pytest.fixture
def mock_agent():
    """Generic agent mock (not referenced by the tests visible in this module section)."""
    agent = MagicMock()
    return agent
@pytest.fixture
def mock_tool():
    """Generic tool mock (not referenced by the tests visible in this module section)."""
    tool = MagicMock()
    return tool
@pytest.fixture
def mock_toolkit():
    """Generic toolkit mock (not referenced by the tests visible in this module section)."""
    toolkit = MagicMock()
    return toolkit
@pytest.fixture
def call_log_helper(mock_session):
    """CallLogHelper over the mock session with organisation id 1."""
    helper = CallLogHelper(mock_session, 1)
    return helper
def test_create_call_log_success(call_log_helper, mock_session):
    """create_call_log returns a CallLogs row after one add and one commit."""
    for method_name in ("add", "commit"):
        setattr(mock_session, method_name, MagicMock())

    created = call_log_helper.create_call_log('test', 1, 10, 'test_tool', 'test_model')

    assert isinstance(created, CallLogs)
    mock_session.add.assert_called_once()
    mock_session.commit.assert_called_once()
def test_create_call_log_failure(call_log_helper, mock_session):
    """A SQLAlchemyError raised on commit makes create_call_log return None."""
    failing_commit = MagicMock(side_effect=SQLAlchemyError())
    mock_session.commit = failing_commit
    assert call_log_helper.create_call_log('test', 1, 10, 'test_tool', 'test_model') is None
def test_fetch_data_success(call_log_helper, mock_session):
    """fetch_data returns a dict containing model, totals and runs keys."""
    mock_session.query = MagicMock()

    summary_row = (1, 1, 1)
    run_rows = [CallLogs(
        agent_execution_name='test',
        agent_id=1,
        tokens_consumed=10,
        tool_used='test_tool',
        model='test_model',
        org_id=1,
    )]
    agent_rows = [Agent(name='test_agent')]
    tool_rows = [Tool(name='test_tool', toolkit_id=1)]
    toolkit_rows = [Toolkit(name='test_toolkit')]
    # Successive .first() calls hand back these values in this order.
    mock_session.query().filter().first.side_effect = [summary_row, run_rows, agent_rows, toolkit_rows, tool_rows]

    data = call_log_helper.fetch_data('test_model')

    assert data is not None
    for expected_key in ('model', 'total_tokens', 'total_calls', 'total_agents', 'runs'):
        assert expected_key in data
def test_fetch_data_failure(call_log_helper, mock_session):
    """A SQLAlchemyError raised by session.query makes fetch_data return None."""
    mock_session.query = MagicMock(side_effect=SQLAlchemyError())
    outcome = call_log_helper.fetch_data('test_model')
    assert outcome is None
================================================
FILE: tests/unit_tests/apm/test_event_handler.py
================================================
import pytest
from sqlalchemy.exc import SQLAlchemyError
from superagi.models.events import Event
from unittest.mock import MagicMock
from superagi.apm.event_handler import EventHandler
@pytest.fixture
def mock_session():
    """MagicMock standing in for the SQLAlchemy session used by EventHandler."""
    fake_session = MagicMock()
    return fake_session
@pytest.fixture
def event_handler(mock_session):
    """EventHandler bound to the mock session."""
    handler = EventHandler(mock_session)
    return handler
def test_create_event_success(event_handler, mock_session):
    """create_event returns an Event after one add and one commit."""
    for method_name in ("add", "commit"):
        setattr(mock_session, method_name, MagicMock())

    created = event_handler.create_event('test', {}, 1, 1, 100)

    assert isinstance(created, Event)
    mock_session.add.assert_called_once()
    mock_session.commit.assert_called_once()
def test_create_event_failure(event_handler, mock_session):
    """A SQLAlchemyError raised on commit makes create_event return None."""
    failing_commit = MagicMock(side_effect=SQLAlchemyError())
    mock_session.commit = failing_commit
    assert event_handler.create_event('test', {}, 1, 1, 100) is None
================================================
FILE: tests/unit_tests/apm/test_knowledge_handler.py
================================================
import pytest
from unittest.mock import MagicMock
from superagi.apm.knowledge_handler import KnowledgeHandler
from fastapi import HTTPException
from datetime import datetime
import pytz
@pytest.fixture
def organisation_id():
    """Organisation id passed to KnowledgeHandler."""
    org_id = 1
    return org_id
@pytest.fixture
def mock_session():
    """MagicMock standing in for the SQLAlchemy session used by KnowledgeHandler."""
    fake_session = MagicMock()
    return fake_session
@pytest.fixture
def knowledge_handler(mock_session, organisation_id):
    """KnowledgeHandler bound to the mock session and organisation id."""
    handler = KnowledgeHandler(mock_session, organisation_id)
    return handler
def test_get_knowledge_usage_by_name(knowledge_handler, mock_session):
    """Usage stats are returned for a known knowledge name; an unknown one raises HTTPException."""
    knowledge_handler.session = mock_session
    name = 'Knowledge1'
    usage_row = MagicMock(knowledge_unique_agents=5, knowledge_name=name, id=1)

    query = mock_session.query.return_value
    query.filter_by.return_value.filter.return_value.first.return_value = usage_row
    query.filter.return_value.group_by.return_value.first.return_value = usage_row
    query.filter.return_value.count.return_value = 10

    usage = knowledge_handler.get_knowledge_usage_by_name(name)
    assert isinstance(usage, dict)
    assert usage == {'knowledge_unique_agents': 5, 'knowledge_calls': 10}

    # Once the knowledge lookup finds nothing, the handler must raise.
    query.filter_by.return_value.filter.return_value.first.return_value = None
    with pytest.raises(HTTPException):
        knowledge_handler.get_knowledge_usage_by_name('NonexistentKnowledge')
def test_get_knowledge_events_by_name(knowledge_handler, mock_session):
    """Knowledge-picked events are merged with run and agent details (one row here)."""
    knowledge_name = 'knowledge1'
    knowledge_handler.session = mock_session
    knowledge_handler.organisation_id = 1

    mock_session.query().filter_by().filter().first.return_value = MagicMock(id=1)

    picked = MagicMock(
        agent_id=1,
        created_at=datetime.now(),
        event_name='knowledge_picked',
        event_property={'knowledge_name': 'knowledge1', 'agent_execution_id': '1'},
    )
    completed = MagicMock(
        agent_id=1,
        event_name='run_completed',
        event_property={'tokens_consumed': 10, 'calls': 5, 'name': 'Runner', 'agent_execution_id': '1'},
    )
    created = MagicMock(
        agent_id=1,
        event_name='agent_created',
        event_property={'agent_name': 'A1', 'model': 'M1'},
    )
    # Successive .all() calls return picked, completed and created events in that order.
    mock_session.query().filter().all.side_effect = [[picked], [completed], [created]]

    mock_session.query().filter().first.return_value = MagicMock(value='America/New_York')

    rows = knowledge_handler.get_knowledge_events_by_name(knowledge_name)

    assert isinstance(rows, list)
    assert len(rows) == 1
    expected_keys = ('agent_execution_id', 'created_at', 'tokens_consumed', 'calls',
                     'agent_execution_name', 'agent_name', 'model')
    for row in rows:
        for key in expected_keys:
            assert key in row
def test_get_knowledge_events_by_name_knowledge_not_found(knowledge_handler, mock_session):
    """An unknown knowledge name raises HTTPException with detail 'Knowledge not found'."""
    knowledge_name = "knowledge1"
    not_found_message = 'Knowledge not found'
    lookup = mock_session.query().filter_by().filter().first
    lookup.return_value = None

    # pytest.raises fails cleanly when nothing is raised. The previous try/assert False/
    # except/finally form let the finally-assert replace the primary failure message.
    with pytest.raises(HTTPException) as exc_info:
        knowledge_handler.get_knowledge_events_by_name(knowledge_name)

    assert str(exc_info.value.detail) == not_found_message, \
        f"Expected {not_found_message}, got {exc_info.value.detail}"
    assert lookup.called, "first() function not called"
================================================
FILE: tests/unit_tests/apm/test_tools_handler.py
================================================
import pytest
from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from superagi.apm.tools_handler import ToolsHandler
from sqlalchemy.orm import Session
from superagi.models.agent_config import AgentConfiguration
from datetime import datetime
import pytz
@pytest.fixture
def organisation_id():
    """Organisation id passed to ToolsHandler."""
    org_id = 1
    return org_id
@pytest.fixture
def mock_session():
    """MagicMock standing in for the SQLAlchemy session used by ToolsHandler."""
    fake_session = MagicMock()
    return fake_session
@pytest.fixture
def tools_handler(mock_session, organisation_id):
    """ToolsHandler bound to the mock session and organisation id."""
    handler = ToolsHandler(mock_session, organisation_id)
    return handler
def test_calculate_tool_usage(tools_handler, mock_session):
    """calculate_tool_usage merges per-tool usage counts with the tool-to-toolkit mapping."""
    # Dotted keys configure attributes on the `.c` child mock of each subquery.
    tool_used_subquery = MagicMock(**{'c.tool_name': 'Tool1', 'c.agent_id': 1})
    agent_count_subquery = MagicMock(**{'c.tool_name': 'Tool1', 'c.unique_agents': 1})
    total_usage_subquery = MagicMock(**{'c.tool_name': 'Tool1', 'c.total_usage': 5})

    query = mock_session.query.return_value
    query.filter_by.return_value.subquery.return_value = tool_used_subquery
    query.group_by.return_value.subquery.side_effect = [agent_count_subquery, total_usage_subquery]
    query.join.return_value.all.return_value = [
        MagicMock(tool_name='Tool1', unique_agents=1, total_usage=5),
    ]
    # Lower-case key: presumably the handler normalises tool names before the lookup.
    tools_handler.get_tool_and_toolkit = MagicMock(return_value={'tool1': 'Toolkit1'})

    usage = tools_handler.calculate_tool_usage()

    assert isinstance(usage, list)
    assert usage == [{'tool_name': 'Tool1', 'unique_agents': 1, 'total_usage': 5, 'toolkit': 'Toolkit1'}]
def test_get_tool_and_toolkit(tools_handler, mock_session):
    """get_tool_and_toolkit maps each tool name to its toolkit name."""
    joined_row = MagicMock(tool_name='tool 1', toolkit_name='toolkit 1')
    mock_session.query().join().all.return_value = [joined_row]
    mapping = tools_handler.get_tool_and_toolkit()
    assert isinstance(mapping, dict)
    assert mapping == {'tool 1': 'toolkit 1'}
def test_get_tool_usage_by_name(tools_handler, mock_session):
    """A known tool returns call/agent counts; an unknown tool raises HTTPException."""
    tools_handler.session = mock_session
    tool_name = 'Tool1'

    tool_row = MagicMock()
    # Assigned after construction: `name=` in the MagicMock constructor names the mock itself.
    tool_row.name = tool_name
    usage_row = MagicMock(
        tool_name=tool_name.lower().replace(" ", ""),
        tool_calls=10,
        tool_unique_agents=5,
    )

    query = mock_session.query.return_value
    query.filter_by.return_value.first.return_value = tool_row
    query.filter.return_value.group_by.return_value.first.return_value = usage_row

    usage = tools_handler.get_tool_usage_by_name(tool_name=tool_name)
    assert isinstance(usage, dict)
    assert usage == {'tool_calls': 10, 'tool_unique_agents': 5}

    query.filter_by.return_value.first.return_value = None
    with pytest.raises(HTTPException):
        tools_handler.get_tool_usage_by_name(tool_name="NonexistentTool")
def test_get_tool_events_by_name(tools_handler, mock_session):
    """Tool-used events are merged with run and agent details (one row here)."""
    tools_handler.session = mock_session
    tools_handler.organisation_id = 1

    mock_session.query().filter_by().first.return_value = MagicMock(id=1)

    used = MagicMock(
        agent_id=1,
        id=1,
        created_at=datetime.now(),
        event_name='tool_used',
        event_property={'tool_name': 'tool1', 'agent_execution_id': '1'},
    )
    completed = MagicMock(
        agent_id=1,
        id=2,
        event_name='run_completed',
        event_property={'tokens_consumed': 10, 'calls': 5, 'name': 'Runner', 'agent_execution_id': '1'},
    )
    created = MagicMock(
        agent_id=1,
        event_name='agent_created',
        event_property={'agent_name': 'A1', 'model': 'M1'},
    )
    # Successive .all() calls return these batches in order; the last is empty.
    mock_session.query().filter().all.side_effect = [[used], [completed], [created], []]

    mock_session.query().filter().first.return_value = MagicMock(value='America/New_York')

    rows = tools_handler.get_tool_events_by_name('Tool1')

    assert isinstance(rows, list)
    assert len(rows) == 1
    expected_keys = ('agent_execution_id', 'created_at', 'tokens_consumed', 'calls',
                     'agent_execution_name', 'agent_name', 'model')
    for row in rows:
        for key in expected_keys:
            assert key in row
def test_get_tool_events_by_name_tool_not_found(tools_handler, mock_session):
    """Asking for events of an unknown tool raises HTTPException after the lookup runs."""
    lookup = mock_session.query().filter_by().first
    lookup.return_value = None
    with pytest.raises(HTTPException):
        tools_handler.get_tool_events_by_name("tool1")
    assert lookup.called
================================================
FILE: tests/unit_tests/controllers/__init__.py
================================================
================================================
FILE: tests/unit_tests/controllers/api/__init__.py
================================================
================================================
FILE: tests/unit_tests/controllers/api/test_agent.py
================================================
import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException
import superagi.config.config
from unittest.mock import MagicMock, patch,Mock
from main import app
from unittest.mock import patch,create_autospec
from sqlalchemy.orm import Session
from superagi.controllers.api.agent import ExecutionStateChangeConfigIn,AgentConfigUpdateExtInput
from superagi.models.agent import Agent
from superagi.models.project import Project
# Module-level HTTP client for the FastAPI app; the tests below send requests through it.
client = TestClient(app)
@pytest.fixture
def mock_api_key_get():
    """Placeholder API key sent in the X-API-Key header."""
    return "your_mock_api_key"
@pytest.fixture
def mock_execution_state_change_input():
    """Empty JSON body for the pause/resume endpoints."""
    return dict()
@pytest.fixture
def mock_run_id_config():
    """Request body listing two run ids."""
    run_ids = [1, 2]
    return {"run_ids": run_ids}
@pytest.fixture
def mock_agent_execution():
    """Empty JSON body for POST /v1/agent/{id}/run."""
    return dict()
@pytest.fixture
def mock_run_id_config_empty():
    """Request body with an empty run id list."""
    return {"run_ids": list()}
@pytest.fixture
def mock_run_id_config_invalid():
    """Request body naming a run id the tests expect to be reported as not found."""
    unknown_run_id = 12310
    return {"run_ids": [unknown_run_id]}
@pytest.fixture
def mock_agent_config_update_ext_input():
    """AgentConfigUpdateExtInput populated with placeholder values and one toolkit."""
    fields = dict(
        tools=[{"name": "Image Generation Toolkit"}],
        schedule=None,
        goal=["Test Goal"],
        instruction=["Test Instruction"],
        constraints=["Test Constraints"],
        iteration_interval=10,
        model="Test Model",
        max_iterations=100,
        agent_type="Test Agent Type",
    )
    return AgentConfigUpdateExtInput(**fields)
@pytest.fixture
def mock_update_agent_config():
    """JSON body for PUT /v1/agent/{id}: renamed agent with one toolkit and default constraints."""
    constraints = [
        "~4000 word limit for short term memory.",
        "Your long term memory is short, so immediately save important information to files.",
        "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
        "No user assistance",
        'Exclusively use the commands listed in double quotes e.g. "command name"',
    ]
    return {
        "name": "agent_3_UPDATED",
        "description": "AI assistant to solve complex problems",
        "goal": ["create a photo of a cat"],
        "agent_type": "Dynamic Task Workflow",
        "constraints": constraints,
        "instruction": ["Be accurate"],
        "tools": [{"name": "Image Generation Toolkit"}],
        "iteration_interval": 500,
        "model": "gpt-4",
        "max_iterations": 100,
    }
# Define test cases
def test_update_agent_not_found(mock_update_agent_config,mock_api_key_get):
    """PUT /v1/agent/1 responds 404 {"detail":"Agent not found"}.

    The auth helpers and both `db` module attributes are replaced with MagicMocks for
    the duration of the request. No session object is injected into those patches.
    """
    with patch('superagi.helper.auth.get_organisation_from_api_key'), \
            patch('superagi.helper.auth.validate_api_key'), \
            patch('superagi.helper.auth.db'), \
            patch('superagi.controllers.api.agent.db'):
        response = client.put(
            "/v1/agent/1",
            headers={"X-API-Key": mock_api_key_get},
            json=mock_update_agent_config
        )
        assert response.status_code == 404
        assert response.text == '{"detail":"Agent not found"}'
def test_get_run_resources_no_run_ids(mock_run_id_config_empty,mock_api_key_get):
    """POST v1/agent/resources/output with an empty run id list responds 404.

    Auth helpers and `db` handles are patched. get_config is forced to return "S3"
    so the storage-type branch is deterministic.
    """
    with patch('superagi.helper.auth.get_organisation_from_api_key'), \
            patch('superagi.helper.auth.validate_api_key'), \
            patch('superagi.helper.auth.db'), \
            patch('superagi.controllers.api.agent.db'), \
            patch('superagi.controllers.api.agent.get_config', return_value="S3"):
        response = client.post(
            "v1/agent/resources/output",
            headers={"X-API-Key": mock_api_key_get},
            json=mock_run_id_config_empty
        )
        assert response.status_code == 404
        assert response.text == '{"detail":"No execution_id found"}'
def test_get_run_resources_invalid_run_ids(mock_run_id_config_invalid,mock_api_key_get):
    """POST v1/agent/resources/output with an unknown run id responds 404.

    Auth helpers and `db` handles are patched. get_config is forced to return "S3"
    so the storage-type branch is deterministic.
    """
    with patch('superagi.helper.auth.get_organisation_from_api_key'), \
            patch('superagi.helper.auth.validate_api_key'), \
            patch('superagi.helper.auth.db'), \
            patch('superagi.controllers.api.agent.db'), \
            patch('superagi.controllers.api.agent.get_config', return_value="S3"):
        response = client.post(
            "v1/agent/resources/output",
            headers={"X-API-Key": mock_api_key_get},
            json=mock_run_id_config_invalid
        )
        assert response.status_code == 404
        assert response.text == '{"detail":"One or more run id(s) not found"}'
def test_resume_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get):
    """POST /v1/agent/1/resume responds 404 {"detail":"Agent not found"}.

    The auth helpers and both `db` module attributes are replaced with MagicMocks.
    """
    with patch('superagi.helper.auth.get_organisation_from_api_key'), \
            patch('superagi.helper.auth.validate_api_key'), \
            patch('superagi.helper.auth.db'), \
            patch('superagi.controllers.api.agent.db'):
        response = client.post(
            "/v1/agent/1/resume",
            headers={"X-API-Key": mock_api_key_get},
            json=mock_execution_state_change_input
        )
        assert response.status_code == 404
        assert response.text == '{"detail":"Agent not found"}'
def test_pause_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get):
    """POST /v1/agent/1/pause responds 404 {"detail":"Agent not found"}.

    The auth helpers and both `db` module attributes are replaced with MagicMocks.
    """
    with patch('superagi.helper.auth.get_organisation_from_api_key'), \
            patch('superagi.helper.auth.validate_api_key'), \
            patch('superagi.helper.auth.db'), \
            patch('superagi.controllers.api.agent.db'):
        response = client.post(
            "/v1/agent/1/pause",
            headers={"X-API-Key": mock_api_key_get},
            json=mock_execution_state_change_input
        )
        assert response.status_code == 404
        assert response.text == '{"detail":"Agent not found"}'
def test_create_run_agent_not_found(mock_agent_execution,mock_api_key_get):
    """POST /v1/agent/1/run responds 404 {"detail":"Agent not found"}.

    The auth helpers and both `db` module attributes are replaced with MagicMocks.
    """
    with patch('superagi.helper.auth.get_organisation_from_api_key'), \
            patch('superagi.helper.auth.validate_api_key'), \
            patch('superagi.helper.auth.db'), \
            patch('superagi.controllers.api.agent.db'):
        response = client.post(
            "/v1/agent/1/run",
            headers={"X-API-Key": mock_api_key_get},
            json=mock_agent_execution
        )
        assert response.status_code == 404
        assert response.text == '{"detail":"Agent not found"}'
def test_create_run_project_not_matching_org(mock_agent_execution, mock_api_key_get):
    """A run for an agent whose project has organisation_id=2 is reported as 404 "Agent not found"."""
    stub_agent = Agent(id=1, project_id=1, agent_workflow_id=1)
    foreign_project = Project(id=1, organisation_id=2)  # Different organisation ID
    with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \
            patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \
            patch('superagi.helper.auth.db') as mock_auth_db, \
            patch('superagi.controllers.api.agent.db') as db_mock:
        fake_session = create_autospec(Session)
        fake_session.query.return_value.filter.return_value.first.return_value = stub_agent
        # NOTE(review): these two stubs target ``db_mock.Project`` and the context-manager
        # form ``db_mock.session()``. If the controller uses ``db.session`` directly and
        # imports ``Project`` itself, neither is reached — verify what yields the 404.
        db_mock.Project.find_by_id.return_value = foreign_project
        db_mock.session.return_value.__enter__.return_value = fake_session
        response = client.post(
            "/v1/agent/1/run",
            headers={"X-API-Key": mock_api_key_get},
            json=mock_agent_execution,
        )
    assert response.status_code == 404
    assert response.text == '{"detail":"Agent not found"}'
================================================
FILE: tests/unit_tests/controllers/test_agent.py
================================================
from unittest.mock import patch, Mock
from unittest import mock
import pytest
from fastapi.testclient import TestClient
from main import app
from superagi.models.agent_schedule import AgentSchedule
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent import Agent
from datetime import datetime, timedelta
from pytz import timezone
client = TestClient(app)
@pytest.fixture
def mock_patch_schedule_input():
    """Schedule-edit payload for agent 1: every 2 hours from 2023-02-02 until 2023-12-30, expiry_runs=-1."""
    payload = dict(
        agent_id=1,
        start_time="2023-02-02 01:00:00",
        recurrence_interval="2 Hours",
        expiry_date="2023-12-30 01:00:00",
        expiry_runs=-1,
    )
    return payload
@pytest.fixture
def mock_schedule():
    """An AgentSchedule (id=1) for agent 1 whose status is SCHEDULED."""
    schedule = AgentSchedule(id=1, agent_id=1, status="SCHEDULED")
    return schedule
@pytest.fixture
def mock_agent_config():
    """The ``user_timezone`` configuration row for agent 1, set to GMT."""
    timezone_row = AgentConfiguration(key="user_timezone", agent_id=1, value='GMT')
    return timezone_row
@pytest.fixture
def mock_schedule_get():
    """A SCHEDULED 5-minute schedule for agent 1.

    It starts 2022-01-01 10:30 (naive datetime), expires ten days later, and
    has expiry_runs=5.
    """
    begins_at = datetime(2022, 1, 1, 10, 30)
    ends_at = begins_at + timedelta(days=10)
    return AgentSchedule(
        id=1,
        agent_id=1,
        status="SCHEDULED",
        start_time=begins_at,
        recurrence_interval="5 Minutes",
        expiry_date=ends_at,
        expiry_runs=5,
    )
# Test for Stopping Agent Scheduling
def test_stop_schedule_success(mock_schedule):
    """Stopping an existing schedule returns 200 and flips its status to STOPPED."""
    with patch('superagi.controllers.agent.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.first.return_value = mock_schedule

        response = client.post("agents/stop/schedule?agent_id=1")

    assert response.status_code == 200
    assert mock_schedule.status == "STOPPED"
def test_stop_schedule_not_found():
    """Stopping when no schedule row exists yields 404 "Schedule not found"."""
    with patch('superagi.controllers.agent.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.first.return_value = None

        response = client.post("agents/stop/schedule?agent_id=1")

    assert response.status_code == 404
    assert response.json() == {"detail": "Schedule not found"}
# Test for editing agent schedule
def test_edit_schedule_success(mock_schedule, mock_patch_schedule_input):
    """Editing an existing schedule returns 200 and copies every payload field onto it."""
    stamp_format = "%Y-%m-%d %H:%M:%S"
    payload = mock_patch_schedule_input
    with patch('superagi.controllers.agent.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.first.return_value = mock_schedule

        response = client.put("agents/edit/schedule", json=payload)

    assert response.status_code == 200
    assert mock_schedule.start_time == datetime.strptime(payload["start_time"], stamp_format)
    assert mock_schedule.recurrence_interval == payload["recurrence_interval"]
    assert mock_schedule.expiry_date == datetime.strptime(payload["expiry_date"], stamp_format)
    assert mock_schedule.expiry_runs == payload["expiry_runs"]
def test_edit_schedule_not_found(mock_patch_schedule_input):
    """Editing when no schedule row exists yields 404 "Schedule not found"."""
    with patch('superagi.controllers.agent.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.first.return_value = None

        response = client.put("agents/edit/schedule", json=mock_patch_schedule_input)

    assert response.status_code == 404
    assert response.json() == {"detail": "Schedule not found"}
# Test for getting agent schedule
def test_get_schedule_data_success(mock_schedule_get, mock_agent_config):
    """Schedule data is rendered in the agent's configured timezone (GMT here)."""
    gmt = timezone('GMT')
    with patch('superagi.controllers.agent.db') as mock_db:
        # The first .first() call yields the schedule, the second the timezone config row.
        mock_db.session.query.return_value.filter.return_value.first.side_effect = [
            mock_schedule_get,
            mock_agent_config,
        ]
        response = client.get("agents/get/schedule_data/1")

    assert response.status_code == 200
    start_in_gmt = mock_schedule_get.start_time.astimezone(gmt)
    expiry_in_gmt = mock_schedule_get.expiry_date.astimezone(gmt)
    assert response.json() == {
        "current_datetime": mock.ANY,
        "start_date": start_in_gmt.strftime("%d %b %Y"),
        "start_time": start_in_gmt.strftime("%I:%M %p"),
        "recurrence_interval": mock_schedule_get.recurrence_interval,
        "expiry_date": expiry_in_gmt.strftime("%d/%m/%Y"),
        "expiry_runs": mock_schedule_get.expiry_runs,
    }
def test_get_schedule_data_not_found():
    """Requesting schedule data when no schedule exists yields 404 "Agent Schedule not found"."""
    with patch('superagi.controllers.agent.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.first.return_value = None

        response = client.get("agents/get/schedule_data/1")

    assert response.status_code == 404
    assert response.json() == {"detail": "Agent Schedule not found"}
@pytest.fixture
def mock_agent_config_schedule():
    """Request body for ``agents/schedule``: a full agent config plus a one-off schedule."""
    agent_config = {
        "name": "SmartAGI",
        "project_id": 1,
        "description": "AI assistant to solve complex problems",
        "goal": ["Share research on latest google news in fashion"],
        "agent_workflow": "Don't Maintain Task Queue",
        "constraints": [
            "~4000 word limit for short term memory.",
            "No user assistance",
            "Exclusively use the commands listed in double quotes",
        ],
        "instruction": [],
        "exit": "Exit strategy",
        "iteration_interval": 500,
        "model": "gpt-4",
        "permission_type": "Type 1",
        "LTM_DB": "Database Pinecone",
        "toolkits": [1],
        "tools": [],
        "memory_window": 10,
        "max_iterations": 25,
        "user_timezone": "Asia/Kolkata",
    }
    # No recurrence and no expiry date: a single start time with expiry_runs=-1.
    schedule = {
        "start_time": "2023-07-04 11:13:00",
        "expiry_runs": -1,
        "recurrence_interval": None,
        "expiry_date": None,
    }
    return {"agent_config": agent_config, "schedule": schedule}
@pytest.fixture
def mock_agent():
    """Agent 1 ("SmartAGI") in project 1."""
    return Agent(id=1, name="SmartAGI", project_id=1)
def test_create_and_schedule_agent_success(mock_agent_config_schedule, mock_agent, mock_schedule):
    """POST agents/schedule creates the agent and its schedule, returning 201 with both ids.

    ``superagi.models.agent.Agent`` is patched at its defining module (not in
    the controller). Presumably Agent.create_agent_with_config builds
    ``Agent(...)`` through that module global, which is how ``mock_agent``
    reaches the response — TODO confirm.
    """
    with patch('superagi.models.agent.Agent') as AgentMock, \
            patch('superagi.controllers.agent.Project') as ProjectMock, \
            patch('superagi.controllers.agent.Tool') as ToolMock, \
            patch('superagi.controllers.agent.Toolkit') as ToolkitMock, \
            patch('superagi.controllers.agent.AgentSchedule') as AgentScheduleMock, \
            patch('superagi.controllers.agent.db') as db_mock:
        project_mock = Mock()
        ProjectMock.get.return_value = project_mock
        db_mock.session.query.return_value.get.return_value = project_mock

        AgentMock.return_value = mock_agent
        ToolMock.get_invalid_tools.return_value = []
        ToolkitMock.fetch_tool_ids_from_toolkit.return_value = []
        # The controller's AgentSchedule(...) yields the fixture, whose id (1) is echoed back.
        AgentScheduleMock.return_value = mock_schedule
        db_mock.session.add.return_value = None

        response = client.post("agents/schedule", json=mock_agent_config_schedule)

        assert response.status_code == 201
        assert response.json() == {
            "id": mock_agent.id,
            "name": mock_agent.name,
            "contentType": "Agents",
            "schedule_id": 1
        }
def test_create_and_schedule_agent_project_not_found(mock_agent_config_schedule):
    """Scheduling an agent into a missing project fails with 404 "Project not found"."""
    with patch('superagi.controllers.agent.db') as mock_db:
        mock_db.session.query.return_value.get.return_value = None
        response = client.post("agents/schedule", json=mock_agent_config_schedule)

    assert response.status_code == 404
    assert response.json() == {"detail": "Project not found"}
================================================
FILE: tests/unit_tests/controllers/test_agent_execution.py
================================================
from unittest.mock import patch
from unittest import mock
import pytest
from fastapi.testclient import TestClient
from main import app
from superagi.models.agent_schedule import AgentSchedule
from datetime import datetime
client = TestClient(app)
@pytest.fixture
def mock_patch_schedule_input():
    """Schedule payload for agent 1: every 2 hours from 2023-02-02 until 2023-12-30, expiry_runs=-1."""
    payload = dict(
        agent_id=1,
        start_time="2023-02-02 01:00:00",
        recurrence_interval="2 Hours",
        expiry_date="2023-12-30 01:00:00",
        expiry_runs=-1,
    )
    return payload
@pytest.fixture
def mock_schedule():
    """An AgentSchedule (id=1) for agent 1 whose status is SCHEDULED."""
    schedule = AgentSchedule(id=1, agent_id=1, status="SCHEDULED")
    return schedule
# The agent already has a schedule that is updated in place; every updated field is checked.
def test_schedule_existing_agent_already_scheduled(mock_patch_schedule_input, mock_schedule):
    stamp_format = '%Y-%m-%d %H:%M:%S'
    payload = mock_patch_schedule_input
    with patch('superagi.controllers.agent_execution.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.first.return_value = mock_schedule

        response = client.post("agentexecutions/schedule", json=payload)

    assert response.status_code == 201
    assert mock_schedule.start_time == datetime.strptime(payload['start_time'], stamp_format)
    assert mock_schedule.recurrence_interval == payload['recurrence_interval']
    assert mock_schedule.expiry_date == datetime.strptime(payload['expiry_date'], stamp_format)
    assert mock_schedule.expiry_runs == payload['expiry_runs']
# The agent isn't scheduled yet: the active-schedule lookup returns None, so the
# controller presumably creates a new AgentSchedule. AgentSchedule is patched in the
# controller so the "created" row is the fixture (id=1), standing in for the id a
# real commit would assign. We assert a 201 status and a non-null schedule id.
def test_schedule_existing_agent_new_schedule(mock_patch_schedule_input, mock_schedule):
    with patch('superagi.controllers.agent_execution.db') as mock_db, \
            patch('superagi.controllers.agent_execution.AgentSchedule') as mock_schedule_cls:
        mock_db.session.query.return_value.filter.return_value.first.return_value = None
        mock_schedule_cls.return_value = mock_schedule

        response = client.post("agentexecutions/schedule", json=mock_patch_schedule_input)

        assert response.status_code == 201
        assert response.json()["schedule_id"] is not None
================================================
FILE: tests/unit_tests/controllers/test_agent_execution_config.py
================================================
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from main import app
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
client = TestClient(app)
@pytest.fixture
def mocks():
    """Agent data for the execution-config tests.

    Returns (agent 1, one of its config rows, execution 54, that execution's
    config list).
    """
    agent = Agent(id=1, name="test_agent", project_id=1, description="testing",
                  agent_workflow_id=1, is_deleted=False)
    agent_config = AgentConfiguration(id=1, agent_id=1, key="test_key", value="['test']")
    execution = AgentExecution(id=54, agent_id=1, name="test_execution")
    execution_configs = [
        AgentExecutionConfiguration(id=64, agent_execution_id=1, key="test_key", value="['test']"),
    ]
    return agent, agent_config, execution, execution_configs
def test_get_agent_execution_configuration_not_found_failure():
    """With no agent row, the execution-config details endpoint responds 404 "Agent not found"."""
    with patch('superagi.controllers.agent_execution_config.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.all.return_value = []
        filtered.first.return_value = None

        response = client.get("/agent_executions_configs/details/agent_id/1/agent_execution_id/1")

    assert response.status_code == 404
    assert response.json() == {"detail": "Agent not found"}
def test_get_agent_execution_configuration_success(mocks):
    """Execution-config details for an existing agent/execution respond 200."""
    mock_agent, _, mock_execution, mock_execution_config = mocks
    with patch('superagi.controllers.agent_execution_config.db') as mock_db:
        filtered = mock_db.session.query.return_value.filter.return_value
        filtered.first.return_value = mock_agent
        filtered.order_by.return_value.first.return_value = mock_execution
        # A single stub serves every query(...).filter(...).all() call, so it can hold
        # only one value. Assigning the agent-config list first (as before) was simply
        # overwritten by this line.
        filtered.all.return_value = mock_execution_config
        # Mock the AgentExecution.get_agent_execution_from_id method to return the mock_execution object
        with patch('superagi.controllers.agent_execution_config.AgentExecution.get_agent_execution_from_id') as mock_get_exec:
            mock_get_exec.return_value = mock_execution
            response = client.get("/agent_executions_configs/details/agent_id/1/agent_execution_id/1")
            assert response.status_code == 200
================================================
FILE: tests/unit_tests/controllers/test_agent_execution_feeds.py
================================================
from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException
from main import app
from fastapi_sqlalchemy import db
from superagi.controllers.agent_execution_feed import get_agent_execution_feed
@patch('superagi.controllers.agent_execution_feed.db')
def test_get_agent_execution_feed(mock_db):
    """Calling the feed handler directly with a patched ``db`` returns a truthy payload.

    Every ``db.session.query(...)`` chain on the MagicMock resolves to further
    MagicMocks, so no real database is touched. The previous version built many
    MagicMock-based locals (executions, feeds, permissions) that were never
    passed anywhere. It also autospecced ``pytest.Session`` (pytest's own
    test-session class) and stubbed ``db()`` rather than ``db.session``; all of
    that is dropped.
    """
    mock_agent_execution_id = 1
    assert get_agent_execution_feed(mock_agent_execution_id)
================================================
FILE: tests/unit_tests/controllers/test_agent_template.py
================================================
from unittest.mock import patch, MagicMock
from superagi.models.agent_template import AgentTemplate
from superagi.models.agent_template_config import AgentTemplateConfig
from fastapi.testclient import TestClient
from main import app
client = TestClient(app)
@patch('superagi.controllers.agent_template.db')
@patch('superagi.helper.auth.db')
@patch('superagi.helper.auth.get_user_organisation')
def test_edit_agent_template_success(mock_get_user_org, mock_auth_db, mock_db):
    """Updating a template renames it, updates its description, and flushes and commits."""
    # Create a mock agent template
    mock_agent_template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")
    # Create a mock edited agent configuration
    mock_updated_agent_configs = {
        "name": "Updated Agent Template",
        "description": "Updated Description",
        "agent_configs": {
            "agent_workflow": "Don't Maintain Task Queue",
            "goal": ["Create a simple pacman game for me.", "Write all files properly."],
            "instruction": ["write spec","write code","improve the code","write test"],
            "constraints": ["If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.","Ensure the tool and args are as per current plan and reasoning","Exclusively use the tools listed under \"TOOLS\"","REMEMBER to format your response as JSON, using double quotes (\"\") around keys and string values, and commas (,) to separate items in arrays and objects. IMPORTANTLY, to use a JSON object as a string in another JSON object, you need to escape the double quotes."],
            "tools": ["Read Email", "Send Email", "Write File"],
            "exit": "No exit criterion",
            "iteration_interval": 500,
            "model": "gpt-4",
            "max_iterations": 25,
            "permission_type": "God Mode",
            "LTM_DB": "Pinecone"
        }
    }
    # Mocking the user organisation
    mock_get_user_org.return_value = MagicMock(id=1)
    # Create a session mock
    session_mock = MagicMock()
    mock_db.session = session_mock
    mock_db.session.query.return_value.filter.return_value.first.return_value = mock_agent_template
    mock_db.session.commit.return_value = None
    mock_db.session.add.return_value = None
    mock_db.session.flush.return_value = None

    # Call the endpoint
    response = client.put("agent_templates/update_agent_template/1", json=mock_updated_agent_configs)

    assert response.status_code == 200
    # Verify changes in the mock agent template. (An AgentTemplateConfig built locally
    # and then asserted on was removed: it only re-checked the test's own constructor
    # arguments, not anything the endpoint did.)
    assert mock_agent_template.name == "Updated Agent Template"
    assert mock_agent_template.description == "Updated Description"
    session_mock.commit.assert_called()
    session_mock.flush.assert_called()
@patch('superagi.controllers.agent_template.db')
@patch('superagi.helper.auth.db')
@patch('superagi.helper.auth.get_user_organisation')
def test_edit_agent_template_failure(mock_get_user_org, mock_auth_db, mock_db):
    """Updating a missing template returns 404 and never commits or flushes the session."""
    mock_get_user_org.return_value = MagicMock(id=1)
    session = MagicMock()
    mock_db.session = session
    session.query.return_value.filter.return_value.first.return_value = None

    response = client.put("agent_templates/update_agent_template/1", json={})

    assert response.status_code == 404
    assert response.json() == {"detail": "Agent Template not found"}
    session.commit.assert_not_called()
    session.flush.assert_not_called()
@patch('superagi.controllers.agent_template.db')
@patch('superagi.helper.auth.db')
@patch('superagi.helper.auth.get_user_organisation')
def test_edit_agent_template_with_new_config_success(mock_get_user_org, mock_auth_db, mock_db):
    """Updating a template whose payload includes an extra config key succeeds and commits."""
    template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")
    payload = {
        "name": "Updated Agent Template",
        "description": "Updated Description",
        "agent_configs": {
            "new_config_key": "New config value",  # intended as a key the template did not have
            "agent_workflow": "Don't Maintain Task Queue",
        },
    }
    mock_get_user_org.return_value = MagicMock(id=1)
    session = MagicMock()
    mock_db.session = session
    session.query.return_value.filter.return_value.first.return_value = template
    session.commit.return_value = None
    session.add.return_value = None
    session.flush.return_value = None

    response = client.put("agent_templates/update_agent_template/1", json=payload)

    assert response.status_code == 200
    assert template.name == "Updated Agent Template"
    assert template.description == "Updated Description"
    session.commit.assert_called()
    session.flush.assert_called()
================================================
FILE: tests/unit_tests/controllers/test_analytics.py
================================================
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
from main import app
client = TestClient(app)
@patch('superagi.controllers.analytics.db')
def test_get_metrics_success(mock_get_db):
    """GET /analytics/metrics returns the helper's run-completion metrics unchanged."""
    # The decorator already patches the controller's ``db``; patching it a second
    # time inside the ``with`` would be redundant.
    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
            patch('superagi.controllers.analytics.AnalyticsHelper') as mock_helper, \
            patch('superagi.helper.auth.db') as mock_auth_db:
        mock_helper().calculate_run_completed_metrics.return_value = {'total_tokens': 10, 'total_calls': 5, 'runs_completed': 2}
        response = client.get("/analytics/metrics")
        assert response.status_code == 200
        assert response.json() == {'total_tokens': 10, 'total_calls': 5, 'runs_completed': 2}
@patch('superagi.controllers.analytics.db')
def test_get_agents_success(mock_get_db):
    """GET /analytics/agents/all returns the helper's agent data unchanged."""
    # The decorator already patches the controller's ``db``; no second patch needed.
    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
            patch('superagi.controllers.analytics.AnalyticsHelper') as mock_helper, \
            patch('superagi.helper.auth.db') as mock_auth_db:
        mock_helper().fetch_agent_data.return_value = {"agent_details": "mock_details", "model_info": "mock_info"}
        response = client.get("/analytics/agents/all")
        assert response.status_code == 200
        assert response.json() == {"agent_details": "mock_details", "model_info": "mock_info"}
@patch('superagi.controllers.analytics.db')
def test_get_agent_runs_success(mock_get_db):
    """GET /analytics/agents/{id} returns the helper's run data for that agent unchanged."""
    # The decorator already patches the controller's ``db``; no second patch needed.
    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
            patch('superagi.controllers.analytics.AnalyticsHelper') as mock_helper, \
            patch('superagi.helper.auth.db') as mock_auth_db:
        mock_helper().fetch_agent_runs.return_value = "mock_agent_runs"
        response = client.get("/analytics/agents/1")
        assert response.status_code == 200
        assert response.json() == "mock_agent_runs"
@patch('superagi.controllers.analytics.db')
def test_get_active_runs_success(mock_get_db):
    """GET /analytics/runs/active returns the helper's active-run list unchanged."""
    # The decorator already patches the controller's ``db``; no second patch needed.
    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
            patch('superagi.controllers.analytics.AnalyticsHelper') as mock_helper, \
            patch('superagi.helper.auth.db') as mock_auth_db:
        mock_helper().get_active_runs.return_value = ["mock_run_1", "mock_run_2"]
        response = client.get("/analytics/runs/active")
        assert response.status_code == 200
        assert response.json() == ["mock_run_1", "mock_run_2"]
@patch('superagi.controllers.analytics.db')
def test_get_tools_user_success(mock_get_db):
    """GET /analytics/tools/used returns the ToolsHandler's usage list unchanged."""
    # The decorator already patches the controller's ``db``; no second patch needed.
    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
            patch('superagi.controllers.analytics.ToolsHandler') as mock_handler, \
            patch('superagi.helper.auth.db') as mock_auth_db:
        mock_handler().calculate_tool_usage.return_value = ["tool1", "tool2"]
        response = client.get("/analytics/tools/used")
        assert response.status_code == 200
        assert response.json() == ["tool1", "tool2"]
================================================
FILE: tests/unit_tests/controllers/test_models_controller.py
================================================
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
from main import app
from llama_cpp import Llama
from llama_cpp import LlamaGrammar
import llama_cpp
from superagi.helper.llm_loader import LLMLoader
client = TestClient(app)
@patch('superagi.controllers.models_controller.db')
def test_store_api_keys_success(mock_get_db):
    """Storing a provider API key responds 200."""
    payload = {"model_provider": "mock_provider", "model_api_key": "mock_key"}
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.post("/models_controller/store_api_keys", json=payload)
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_get_api_keys_success(mock_get_db):
    """Listing the organisation's stored API keys responds 200."""
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.get("/models_controller/get_api_keys")
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
@patch('superagi.controllers.models_controller.ModelsConfig.fetch_api_key', return_value={})
def test_get_api_key_success(mock_fetch_api_key, mock_get_db):
    """Fetching one provider's key responds 200 when ModelsConfig.fetch_api_key yields {}."""
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.get("/models_controller/get_api_key", params={"model_provider": "model"})
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_verify_end_point_success(mock_get_db):
    """Endpoint verification with mock key/endpoint/provider query params responds 200."""
    url = "/models_controller/verify_end_point?model_api_key=mock_key&end_point=mock_point&model_provider=mock_provider"
    with patch('superagi.helper.auth.db'):
        response = client.get(url)
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_store_model_success(mock_get_db):
    """Registering a model definition responds 200."""
    model_definition = dict(
        model_name="mock_model",
        description="mock_description",
        end_point="mock_end_point",
        model_provider_id=1,
        token_limit=10,
        type="mock_type",
        version="mock_version",
        context_length=4096,
    )
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.post("/models_controller/store_model", json=model_definition)
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_fetch_models_success(mock_get_db):
    """Listing the organisation's models responds 200."""
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.get("/models_controller/fetch_models")
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_fetch_model_details_success(mock_get_db):
    """Fetching details of model 1 responds 200."""
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.get("/models_controller/fetch_model/1")
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_fetch_data_success(mock_get_db):
    """Requesting model usage data for a named model responds 200."""
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.post("/models_controller/fetch_model_data", json={"model": "model"})
    assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_get_marketplace_models_list_with_mocked_request_success(mock_get_db):
    """The marketplace listing responds 200 when the outbound ``requests.get`` is stubbed.

    This test needs a name distinct from ``test_get_marketplace_models_list_success``.
    With a duplicate ``def``, the later definition replaces this one at import
    time and pytest never collects it.
    """
    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
            patch('superagi.helper.auth.db') as mock_auth_db, \
            patch('superagi.controllers.models_controller.requests.get') as mock_get:
        mock_response = MagicMock()
        mock_response.status_code = 200
        # Give .json() a JSON-serialisable value in case the handler returns it directly.
        mock_response.json.return_value = []
        mock_get.return_value = mock_response
        response = client.get("/models_controller/marketplace/list/0")
        assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
def test_get_marketplace_models_list_success(mock_get_db):
    """The marketplace listing (page 0) responds 200; requests.get is not patched in this variant."""
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.helper.auth.db'):
        response = client.get("/models_controller/marketplace/list/0")
    assert response.status_code == 200
def test_get_local_llm():
    """The local-LLM test endpoint responds 200 with LLMLoader's model and grammar stubbed."""
    canned_reply = {"choices": [{"message": {"content": "Hello!"}}]}
    with patch.object(LLMLoader, 'model', new_callable=MagicMock) as mock_model, \
            patch.object(LLMLoader, 'grammar', new_callable=MagicMock):
        mock_model.create_chat_completion.return_value = canned_reply
        response = client.get("/models_controller/test_local_llm")
    assert response.status_code == 200
================================================
FILE: tests/unit_tests/controllers/test_publish_agent.py
================================================
import pytest
from fastapi.testclient import TestClient
from unittest.mock import create_autospec, patch
from main import app
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.organisation import Organisation
from superagi.models.user import User
from sqlalchemy.orm import Session
client = TestClient(app)
@pytest.fixture
def mocks():
    """Agent data for publishing tests.

    Returns (agent 1, one of its config rows, execution 1, that execution's
    config list).
    """
    agent = Agent(id=1, name="test_agent", project_id=1, description="testing",
                  agent_workflow_id=1, is_deleted=False)
    agent_config = AgentConfiguration(id=1, agent_id=1, key="test_key", value="['test']")
    execution = AgentExecution(id=1, agent_id=1, name="test_execution")
    execution_configs = [
        AgentExecutionConfiguration(id=1, agent_execution_id=1, key="test_key", value="['test']"),
    ]
    return agent, agent_config, execution, execution_configs
def test_publish_template(mocks):
    """Publishing an agent execution as a template responds 201.

    Only the execution lookup is stubbed with real data. The controller's
    ``db`` is a plain MagicMock. A standalone autospec Session used to be
    configured here, but it was never attached to ``mock_db`` (and its ``.all``
    stub overwrote itself), so it had no effect and has been removed.
    """
    _, _, mock_execution, _ = mocks
    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
            patch('superagi.helper.auth.get_current_user') as mock_get_user, \
            patch('superagi.helper.auth.db') as mock_auth_db, \
            patch('superagi.controllers.agent_template.db') as mock_db:
        # This patches the attribute on the AgentExecution class object itself, so it
        # applies wherever that class is used, including in the agent_template controller.
        with patch('superagi.controllers.agent_execution_config.AgentExecution.get_agent_execution_from_id') as mock_get_exec:
            mock_get_exec.return_value = mock_execution
            response = client.post("/agent_templates/publish_template/agent_execution_id/1")
            assert response.status_code == 201
================================================
FILE: tests/unit_tests/controllers/test_tool.py
================================================
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from main import app
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit
client = TestClient(app)
@pytest.fixture
def mocks():
    """Organisation 1, two toolkits, and three tools split across them.

    Returns (organisation, toolkits, tools, toolkit_1, toolkit_2, tool_1,
    tool_2, tool_3). tool_1 and tool_2 belong to toolkit 1 and tool_3 to
    toolkit 2. All tools share id=1, which is what the expected tool-list
    response checks.
    """
    user_organisation = Organisation(id=1)
    toolkit_1 = Toolkit(
        id=1,
        name="toolkit_1",
        description="None",
        show_toolkit=None,
        organisation_id=1
    )
    # id=2 (was 1) so tool_3's toolkit_id=2 refers to a toolkit that exists in this data set.
    toolkit_2 = Toolkit(
        id=2,
        name="toolkit_2",
        description="None",
        show_toolkit=None,
        organisation_id=1
    )
    user_toolkits = [toolkit_1, toolkit_2]
    tool_1 = Tool(
        id=1,
        name="tool_1",
        description="Test Tool",
        folder_name="test folder",
        file_name="test file",
        toolkit_id=1
    )
    tool_2 = Tool(
        id=1,
        name="tool_2",
        description="Test Tool",
        folder_name="test folder",
        file_name="test file",
        toolkit_id=1
    )
    tool_3 = Tool(
        id=1,
        name="tool_3",
        description="Test Tool",
        folder_name="test folder",
        file_name="test file",
        toolkit_id=2
    )
    tools = [tool_1, tool_2, tool_3]
    return user_organisation, user_toolkits, tools, toolkit_1, toolkit_2, tool_1, tool_2, tool_3
def test_get_tools_success(mocks):
    """GET /tools/list returns the tools of the user's toolkits as plain dicts."""
    _, user_toolkits, _, _, _, tool_1, tool_2, tool_3 = mocks

    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.controllers.tool.db') as db_stub, \
            patch('superagi.helper.auth.db'):
        # Successive `.all()` results: the toolkits first, then tool lists
        # (presumably one query per toolkit -- see the controller).
        db_stub.session.query.return_value.filter.return_value.all.side_effect = [
            user_toolkits,
            [tool_1, tool_2],
            [tool_3],
        ]

        response = client.get("/tools/list")

        expected = [
            {'id': 1, 'name': name, 'description': 'Test Tool', 'folder_name': 'test folder',
             'file_name': 'test file', 'toolkit_id': kit_id}
            for name, kit_id in (('tool_1', 1), ('tool_2', 1), ('tool_3', 2))
        ]
        assert response.status_code == 200
        assert response.json() == expected
================================================
FILE: tests/unit_tests/controllers/test_tool_config.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from main import app
from superagi.controllers.tool_config import update_tool_config
from superagi.models.organisation import Organisation
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
# Module-level test client bound to the FastAPI `app` from `main`; shared by
# every test in this file.
client = TestClient(app)
@pytest.fixture
def mocks():
    """Organisation, two toolkits and one tool config for the tool-config tests.

    Returns ``(user_organisation, user_toolkits, tool_config, toolkit_1,
    toolkit_2)``. Both toolkits use ``id=1``; the config belongs to toolkit 1.
    """
    user_organisation = Organisation(id=1)
    toolkit_1, toolkit_2 = (
        Toolkit(id=1, name=toolkit_name, description="None",
                show_toolkit=None, organisation_id=1)
        for toolkit_name in ("toolkit_1", "toolkit_2")
    )
    tool_config = ToolConfig(id=1, key="test_key", value="test_value", toolkit_id=1)
    return (user_organisation, [toolkit_1, toolkit_2], tool_config,
            toolkit_1, toolkit_2)
# Test cases
def test_update_tool_configs_success():
    """update_tool_config reports success when each config key already has a row."""
    configs = [
        {"key": "config_1", "value": "value_1"},
        {"key": "config_2", "value": "value_2"},
    ]

    with patch('superagi.models.toolkit.Toolkit.get_toolkit_from_name'), \
            patch('superagi.controllers.tool_config.db') as db_stub:
        # Each successive `.first()` lookup yields the next pre-existing row.
        existing_rows = [
            MagicMock(toolkit_id=1, key="config_1", value="old_value_1"),
            MagicMock(toolkit_id=1, key="config_2", value="old_value_2"),
        ]
        db_stub.query.return_value.filter_by.return_value.first.side_effect = existing_rows

        outcome = update_tool_config("toolkit_1", configs)
        assert outcome == {"message": "Tool configs updated successfully"}
def test_get_all_tool_configs_success(mocks):
    """GET /tool_configs/get/toolkit/{name} lists the configs of a known toolkit."""
    _, _, tool_config, toolkit_1, _ = mocks
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.controllers.tool_config.db') as mock_db, \
            patch('superagi.helper.auth.db'):
        # Toolkit lookup (filter_by(...).first()) and config listing (filter(...).all()).
        mock_db.session.query.return_value.filter_by.return_value.first.return_value = toolkit_1
        mock_db.session.query.return_value.filter.return_value.all.side_effect = [
            [tool_config]
        ]

        # Plain string literal: the URL has no placeholders, so the old f-prefix was noise (F541).
        response = client.get("/tool_configs/get/toolkit/test_toolkit_1")

        assert response.status_code == 200
        assert response.json() == [
            {
                'id': tool_config.id,
                'key': tool_config.key,
                'value': tool_config.value,
                'toolkit_id': tool_config.toolkit_id,
            }
        ]
def test_get_all_tool_configs_toolkit_not_found(mocks):
    """Listing the configs of an unknown toolkit answers 404 'ToolKit not found'."""
    # `mocks` is requested only so the fixture is set up; none of its values are needed.
    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.controllers.tool_config.db') as mock_db, \
            patch('superagi.helper.auth.db'):
        # NOTE(review): the success test stubs `filter_by(...).first()` for the
        # toolkit lookup, while this one stubs `filter(...).first()`; confirm
        # which query path the controller really uses for the 404 case.
        mock_db.session.query.return_value.filter.return_value.first.return_value = None

        # Plain string literal: no placeholders, so no f-prefix (F541).
        response = client.get("/tool_configs/get/toolkit/nonexistent_toolkit")

        assert response.status_code == 404
        assert response.json() == {'detail': 'ToolKit not found'}
def test_get_tool_config_success(mocks):
    """GET /tool_configs/get/toolkit/{toolkit}/key/{key} returns the stored config."""
    _, user_toolkits, tool_config, toolkit_1, _ = mocks
    url = f"/tool_configs/get/toolkit/{toolkit_1.name}/key/{tool_config.key}"

    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.controllers.tool_config.db') as db_stub, \
            patch('superagi.helper.auth.db'):
        query = db_stub.session.query.return_value
        query.filter.return_value.all.return_value = user_toolkits
        query.filter_by.return_value = toolkit_1
        query.filter.return_value.first.return_value = tool_config

        response = client.get(url)

        expected = {
            "id": tool_config.id,
            "key": tool_config.key,
            "value": tool_config.value,
            "toolkit_id": tool_config.toolkit_id,
        }
        assert response.status_code == 200
        assert response.json() == expected
def test_get_tool_config_unauthorized(mocks):
    """A config lookup under ``toolkit_2`` is expected to be rejected with 403."""
    _, user_toolkits, tool_config, _, toolkit_2 = mocks

    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.controllers.tool_config.db') as db_stub, \
            patch('superagi.helper.auth.db'):
        # Only the toolkit listing is stubbed; no config lookup is configured.
        db_stub.session.query.return_value.filter.return_value.all.return_value = user_toolkits

        response = client.get(
            f"/tool_configs/get/toolkit/{toolkit_2.name}/key/{tool_config.key}"
        )

        assert response.status_code == 403
        assert response.json() == {"detail": "Unauthorized"}
def test_get_tool_config_not_found(mocks):
    """An existing toolkit whose config key has no row answers 404."""
    _, user_toolkits, tool_config, toolkit_1, _ = mocks

    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.controllers.tool_config.db') as db_stub, \
            patch('superagi.helper.auth.db'):
        query = db_stub.session.query.return_value
        query.filter.return_value.all.return_value = user_toolkits
        query.filter_by.return_value = toolkit_1
        # The toolkit exists, but the config lookup comes back empty.
        query.filter.return_value.first.return_value = None

        response = client.get(f"/tool_configs/get/toolkit/{toolkit_1.name}/key/{tool_config.key}")

        assert response.status_code == 404
        assert response.json() == {"detail": "Tool configuration not found"}
================================================
FILE: tests/unit_tests/controllers/test_toolkit.py
================================================
from unittest.mock import patch, call
import pytest
from fastapi.testclient import TestClient
from main import app
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
from superagi.models.tool_config import ToolConfig
from superagi.types.key_type import ToolConfigKeyType
from superagi.models.toolkit import Toolkit
# Module-level test client bound to the FastAPI `app` from `main`; shared by
# every test in this file.
client = TestClient(app)
@pytest.fixture
def mocks():
    """Organisation, toolkits and tools for the toolkit controller tests.

    Returns ``(user_organisation, user_toolkits, tools, toolkit_1, toolkit_2,
    tool_1, tool_2, tool_3)``. Every toolkit and tool uses ``id=1``;
    ``tool_3`` references ``toolkit_id=2``.
    """
    def make_toolkit(name):
        return Toolkit(id=1, name=name, description="None",
                       show_toolkit=None, organisation_id=1)

    def make_tool(name, toolkit_id):
        return Tool(id=1, name=name, description="Test Tool",
                    folder_name="test folder", file_name="test file",
                    toolkit_id=toolkit_id)

    user_organisation = Organisation(id=1)
    toolkit_1 = make_toolkit("toolkit_1")
    toolkit_2 = make_toolkit("toolkit_2")
    tool_1 = make_tool("tool_1", 1)
    tool_2 = make_tool("tool_2", 1)
    tool_3 = make_tool("tool_3", 2)
    return (user_organisation, [toolkit_1, toolkit_2], [tool_1, tool_2, tool_3],
            toolkit_1, toolkit_2, tool_1, tool_2, tool_3)
@pytest.fixture
def mock_toolkit_details():
    """Marketplace details payload for ``toolkit_1``.

    Contains two tools and two configs (one STRING, one FILE); both configs
    are secret and optional.
    """
    tools = [
        {"name": "tool_1", "description": "Test Tool 1", "folder_name": "test_folder_1",
         "class_name": "TestTool1", "file_name": "test_tool_1.py"},
        {"name": "tool_2", "description": "Test Tool 2", "folder_name": "test_folder_2",
         "class_name": "TestTool2", "file_name": "test_tool_2.py"},
    ]
    configs = [
        {"key": "config_key_1", "value": "config_value_1",
         "key_type": ToolConfigKeyType.STRING, "is_secret": True, "is_required": False},
        {"key": "config_key_2", "value": "config_value_2",
         "key_type": ToolConfigKeyType.FILE, "is_secret": True, "is_required": False},
    ]
    return dict(
        name="toolkit_1",
        description="Test Toolkit",
        tool_code_link="https://example.com/toolkit_1",
        show_toolkit=None,
        tools=tools,
        configs=configs,
    )
def test_handle_marketplace_operations_list(mocks):
    """GET /toolkits/get/list returns marketplace toolkits with an ``is_installed`` flag."""
    _, user_toolkits, _, toolkit_1, toolkit_2, _, _, _ = mocks

    with patch('superagi.helper.auth.get_user_organisation'), \
            patch('superagi.controllers.toolkit.db') as db_stub, \
            patch('superagi.models.toolkit.Toolkit.fetch_marketplace_list') as marketplace_stub, \
            patch('superagi.helper.auth.db'):
        db_stub.session.query.return_value.filter.return_value.all.side_effect = [user_toolkits]
        marketplace_stub.return_value = [kit.to_dict() for kit in (toolkit_1, toolkit_2)]

        response = client.get("/toolkits/get/list", params={"page": 0})

        expected = [
            {
                "id": 1,
                "name": kit_name,
                "description": "None",
                "show_toolkit": None,
                "organisation_id": 1,
                "is_installed": True,
            }
            for kit_name in ("toolkit_1", "toolkit_2")
        ]
        assert response.status_code == 200
        assert response.json() == expected
def test_install_toolkit_from_marketplace(mock_toolkit_details):
    """Installing ``toolkit_1`` fetches its marketplace details and reports success."""
    with patch('superagi.helper.auth.get_user_organisation') as org_stub, \
            patch('superagi.models.toolkit.Toolkit.fetch_marketplace_detail') as detail_stub, \
            patch('superagi.models.toolkit.Toolkit.add_or_update') as toolkit_upsert, \
            patch('superagi.models.tool.Tool.add_or_update'), \
            patch('superagi.controllers.toolkit.db'), \
            patch('superagi.helper.auth.db'), \
            patch('superagi.models.tool_config.ToolConfig.add_or_update'):
        org_stub.return_value = Organisation(id=1)
        detail_stub.return_value = mock_toolkit_details
        toolkit_upsert.return_value = Toolkit(
            id=1,
            name=mock_toolkit_details['name'],
            description=mock_toolkit_details['description'],
        )

        response = client.get("/toolkits/get/install/toolkit_1")

        assert response.status_code == 200
        assert response.json() == {"message": "ToolKit installed successfully"}
        detail_stub.assert_called_once_with(search_str="details", toolkit_name="toolkit_1")
================================================
FILE: tests/unit_tests/controllers/test_update_agent_config_table.py
================================================
import pytest
from unittest.mock import patch, Mock
from superagi.models.agent_config import AgentConfiguration
from superagi.controllers.types.agent_execution_config import AgentRunIn
def test_update_existing_toolkits():
    """An existing 'toolkits' row is overwritten with the new ids and the session commits."""
    agent_id = 1
    updated_details = AgentRunIn(
        agent_workflow="test", constraints=["c1", "c2"], toolkits=[1, 2],
        tools=[1, 2, 3], exit="exit", iteration_interval=1,
        model="test", permission_type="p", LTM_DB="LTM", max_iterations=100
    )

    # Pre-existing 'toolkits' configuration row for the agent.
    existing_toolkits_config = Mock(spec=AgentConfiguration)
    existing_toolkits_config.key = "toolkits"
    existing_toolkits_config.value = [3, 4]
    agent_configs = [existing_toolkits_config]

    mock_session = Mock()
    # Any query(...).filter(...).all() returns the existing rows.
    mock_session.query().filter().all.return_value = agent_configs

    result = AgentConfiguration.update_agent_configurations_table(mock_session, agent_id, updated_details)

    # The stored value becomes the string form of the new toolkit ids.
    assert existing_toolkits_config.value == '[1, 2]'
    # `assert mock_session.commit.called_once()` never checked anything: on a
    # plain Mock it auto-creates and calls a child mock (always truthy), and on
    # Python 3.12+ it raises AttributeError. Use the real mock assertion.
    mock_session.commit.assert_called_once()
    assert result == "Details updated successfully"
================================================
FILE: tests/unit_tests/controllers/test_user.py
================================================
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from main import app
from superagi.models.user import User
# Module-level test client bound to the FastAPI `app` from `main`; shared by
# every test in this file.
client = TestClient(app)
# Define a fixture for an authenticated user
@pytest.fixture
def authenticated_user():
    """A ``User`` with id 1, no first-login source yet, and a fake bearer token."""
    user = User()
    for attribute, value in (
        ("id", 1),
        ("username", "testuser"),
        ("email", "super6@agi.com"),
        ("first_login_source", None),
        ("token", "mock-jwt-token"),
    ):
        setattr(user, attribute, value)
    return user
# Test case for updating first login source when it's not set
def test_update_first_login_source(authenticated_user):
    """POSTing a first-login source for a user without one echoes it back."""
    source = "github"
    auth_headers = {"Authorization": f"Bearer {authenticated_user.token}"}

    with patch('superagi.helper.auth.db') as mock_auth_db:
        # The auth layer's user lookup resolves to the fixture user.
        mock_auth_db.session.query.return_value.filter.return_value.first.return_value = authenticated_user

        # Absolute path, consistent with every other request in this suite
        # (the old relative "users/..." only worked via base-URL joining).
        response = client.post(f"/users/first_login_source/{source}", headers=auth_headers)

        assert response.status_code == 200
        # Parse the body only after the status check so a failure reports the status.
        body = response.json()
        assert "first_login_source" in body
        assert body["first_login_source"] == "github"
================================================
FILE: tests/unit_tests/helper/__init__.py
================================================
================================================
FILE: tests/unit_tests/helper/test_agent_schedule_helper.py
================================================
import pytest
from unittest.mock import patch, MagicMock, call
from superagi.helper.agent_schedule_helper import AgentScheduleHelper
from superagi.models.agent_schedule import AgentSchedule
from datetime import datetime, timedelta
@patch('superagi.helper.agent_schedule_helper.parse_interval_to_seconds')
@patch('superagi.models.agent_schedule.AgentSchedule')
@patch('superagi.helper.agent_schedule_helper.Session')
@patch('superagi.helper.agent_schedule_helper.datetime')
def test_update_next_scheduled_time(mock_datetime, mock_session, mock_agent_schedule, mock_parse_interval_to_seconds):
    """update_next_scheduled_time opens one session, queries once and parses the interval.

    ``@patch`` decorators apply bottom-up, so the mocks arrive in reverse
    decorator order.
    """
    mock_datetime.now.return_value = datetime(2022, 1, 1, 10, 0)

    scheduled_agent = MagicMock()
    scheduled_agent.start_time = datetime(2022, 1, 1, 1, 0)
    scheduled_agent.next_scheduled_time = datetime(2022, 1, 1, 1, 0)
    scheduled_agent.recurrence_interval = '5 Minutes'
    scheduled_agent.status = 'SCHEDULED'
    # NOTE(review): this patches the model module, not the `AgentSchedule` name
    # inside superagi.helper.agent_schedule_helper (test_run_scheduled_agents
    # patches the latter); confirm whether this patch affects the helper at all.
    mock_agent_schedule.return_value = scheduled_agent

    session = mock_session.return_value
    session.query.return_value.filter.return_value.all.return_value = [scheduled_agent]
    mock_parse_interval_to_seconds.return_value = 300

    AgentScheduleHelper().update_next_scheduled_time()

    mock_session.assert_called_once()
    session.query.assert_called_once()
    session.query.return_value.filter.assert_called()
    session.query.return_value.filter.return_value.all.assert_called_once()
    mock_parse_interval_to_seconds.assert_called_once_with('5 Minutes')
    assert scheduled_agent.status == 'SCHEDULED'
# The private `__`-prefixed helpers are patched through their name-mangled
# `_AgentScheduleHelper__...` attribute names.
@patch('superagi.helper.agent_schedule_helper.AgentScheduleHelper._AgentScheduleHelper__create_execution_name_for_scheduling')
@patch('superagi.helper.agent_schedule_helper.AgentScheduleHelper._AgentScheduleHelper__should_execute_agent')
@patch('superagi.helper.agent_schedule_helper.AgentScheduleHelper._AgentScheduleHelper__can_remove_agent')
@patch('superagi.helper.agent_schedule_helper.AgentScheduleHelper._AgentScheduleHelper__execute_schedule')
@patch('superagi.helper.agent_schedule_helper.parse_interval_to_seconds')
@patch('superagi.helper.agent_schedule_helper.AgentSchedule')
@patch('superagi.helper.agent_schedule_helper.Session')
@patch('superagi.helper.agent_schedule_helper.datetime')
def test_run_scheduled_agents(
    # `@patch` decorators apply bottom-up: the bottom decorator supplies the first argument.
    mock_datetime,
    mock_session,
    mock_agent_schedule,
    mock_parse_interval_to_seconds,
    mock_execute_schedule,
    mock_can_remove_agent,
    mock_should_execute_agent,
    mock_create_execution_name_for_scheduling
):
    """Check how run_scheduled_agents wires its (patched) helpers for one SCHEDULED agent."""
    # Mocking current datetime
    mock_datetime.now.return_value = datetime(2022, 1, 1, 10, 0)
    # Mocking agent object: scheduled 5 minutes before "now"
    mock_agent = MagicMock(spec=AgentSchedule)
    mock_agent.next_scheduled_time = datetime(2022, 1, 1, 9, 55)
    mock_agent.status = 'SCHEDULED'
    mock_agent.recurrence_interval = '5 Minutes'
    mock_agent.agent_id = 'agent_1'
    # Mocking the return value of the session query
    mock_session.return_value.query.return_value.filter.return_value.all.return_value = [mock_agent]
    mock_parse_interval_to_seconds.return_value = 300
    mock_should_execute_agent.return_value = True
    mock_can_remove_agent.return_value = False
    mock_create_execution_name_for_scheduling.return_value = 'Run 01 January 2022 10:00'
    # Call the method
    helper = AgentScheduleHelper()
    helper.run_scheduled_agents()
    # Assert that the mocks were called as expected
    mock_session.assert_called_once_with()
    mock_session.return_value.query.assert_called_once_with(mock_agent_schedule)
    mock_session.return_value.query.return_value.filter.assert_called_once()
    mock_session.return_value.query.return_value.filter.return_value.all.assert_called_once()
    mock_parse_interval_to_seconds.assert_has_calls([call('5 Minutes')])
    mock_should_execute_agent.assert_called_once_with(mock_agent, mock_agent.recurrence_interval)
    mock_can_remove_agent.assert_called_once_with(mock_agent, mock_agent.recurrence_interval)
    # NOTE: `mock_session()` below calls the Session mock again (returning the
    # same `return_value`); it must stay AFTER `mock_session.assert_called_once_with()`
    # above, otherwise that assertion would see two calls.
    mock_execute_schedule.assert_has_calls([call(
        mock_should_execute_agent.return_value,
        mock_parse_interval_to_seconds.return_value,
        mock_session(),
        mock_agent,
        mock_create_execution_name_for_scheduling.return_value
    )])
    mock_create_execution_name_for_scheduling.assert_called_once_with(mock_agent.agent_id)
================================================
FILE: tests/unit_tests/helper/test_calendar_date.py
================================================
import unittest
from unittest.mock import MagicMock
from datetime import datetime, timezone
import pytz
from superagi.helper.calendar_date import CalendarDate
class TestCalendarDate(unittest.TestCase):
    """Tests for CalendarDate's time-zone lookup and date/time conversion helpers."""

    def setUp(self):
        self.cd = CalendarDate()
        # Stubbed calendar service: calendars().get().execute() returns
        # metadata reporting the Asia/Kolkata time zone (UTC+05:30).
        self.service = MagicMock()
        self.service.calendars().get().execute.return_value = {'timeZone': 'Asia/Kolkata'}

    def test_get_time_zone(self):
        time_zone = self.cd._get_time_zone(self.service)
        self.assertEqual(time_zone, 'Asia/Kolkata')

    def test_convert_to_utc(self):
        # Midnight of January 1st, 2023 in US/Pacific, which is 8 hours behind
        # UTC during standard time.
        local_datetime = datetime(2023, 1, 1)
        local_tz = pytz.timezone('US/Pacific')
        utc_datetime = self.cd._convert_to_utc(local_datetime, local_tz)
        expected_utc_datetime = datetime(2023, 1, 1, 8, 0)
        expected_utc_datetime = pytz.timezone('GMT').localize(expected_utc_datetime)
        # Use the TestCase assertion like the rest of this class: a bare
        # `assert` is stripped under `python -O` and reports less on failure.
        self.assertEqual(utc_datetime, expected_utc_datetime)

    def test_string_to_datetime(self):
        date_str = '2022-01-01'
        date_format = '%Y-%m-%d'
        date_obj = datetime.strptime(date_str, date_format)
        self.assertEqual(date_obj, self.cd._string_to_datetime(date_str, date_format))

    def test_localize_daterange(self):
        start_date, end_date = '2022-01-01', '2022-01-02'
        start_time, end_time = '10:00:00', '12:00:00'
        local_tz = pytz.timezone('Asia/Kolkata')
        start_datetime_utc, end_datetime_utc = self.cd._localize_daterange(start_date, end_date, start_time, end_time,
                                                                           local_tz)
        # Asia/Kolkata is UTC+05:30, so 10:00 and 12:00 local become 04:30 and 06:30 UTC.
        self.assertEqual(start_datetime_utc, datetime(2022, 1, 1, 4, 30, tzinfo=timezone.utc))
        self.assertEqual(end_datetime_utc, datetime(2022, 1, 2, 6, 30, tzinfo=timezone.utc))

    def test_datetime_to_string(self):
        date_time = datetime(2022, 1, 1, 0, 0, 0)
        date_format = '%Y-%m-%d'
        date_str = '2022-01-01'
        self.assertEqual(date_str, self.cd._datetime_to_string(date_time, date_format))

    def test_get_date_utc(self):
        start_date, end_date = '2022-01-01', '2022-01-02'
        start_time, end_time = '10:00:00', '12:00:00'
        date_utc = {
            "start_datetime_utc": "2022-01-01T04:30:00.000000Z",
            "end_datetime_utc": "2022-01-02T06:30:00.000000Z"
        }
        result = self.cd.get_date_utc(start_date, end_date, start_time, end_time, self.service)
        self.assertEqual(date_utc, result)

    def test_create_event_dates(self):
        start_date, end_date = '2022-01-01', '2022-01-02'
        start_time, end_time = '10:00:00', '12:00:00'
        date_utc = {
            "start_datetime_utc": "2022-01-01T04:30:00.000000Z",
            "end_datetime_utc": "2022-01-02T06:30:00.000000Z",
            "timeZone": "Asia/Kolkata"
        }
        # Note the argument order here (start date, start time, end date, end
        # time) differs from get_date_utc's.
        result = self.cd.create_event_dates(self.service, start_date, start_time, end_date, end_time)
        self.assertEqual(date_utc, result)
# Allow running this module directly (`python <file>`) as well as via a test runner.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/unit_tests/helper/test_error_handling.py
================================================
import pytest
from unittest.mock import Mock, patch
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.helper.error_handler import ErrorHandler
def test_handle_error():
    """ErrorHandler.handle_openai_errors is expected to query AgentExecution exactly once."""
    execution_id = 2

    session = Mock()
    execution_query = Mock()
    # Any filter(...).first() on the query yields the execution under test.
    execution_query.filter().first.return_value = AgentExecution(id=execution_id)
    session.query.return_value = execution_query

    ErrorHandler.handle_openai_errors(session, 1, execution_id, 'Test error')

    session.query.assert_called_once_with(AgentExecution)
================================================
FILE: tests/unit_tests/helper/test_feed_parser.py
================================================
import unittest
from datetime import datetime
from superagi.helper.feed_parser import parse_feed
from superagi.models.agent_execution_feed import AgentExecutionFeed
class TestParseFeed(unittest.TestCase):
    """Tests for ``parse_feed``."""

    def test_parse_feed_system(self):
        # NOTE(review): the test name and failure messages mention the "system"
        # role, but the feed below uses role="user"; confirm which was intended.
        feed = AgentExecutionFeed(
            id=2, agent_execution_id=100, agent_id=200, role="user",
            feed='System message', updated_at=datetime.now()
        )
        parsed = parse_feed(feed)
        # The parsed dict should carry the feed text and role through unchanged.
        for field in ('feed', 'role'):
            self.assertEqual(parsed[field], getattr(feed, field),
                             "Incorrect output from parse_feed function for system role")
================================================
FILE: tests/unit_tests/helper/test_github_helper.py
================================================
import base64
import unittest
from unittest.mock import patch, MagicMock
from superagi.helper.github_helper import GithubHelper
class TestGithubHelper(unittest.TestCase):
    """Tests for GithubHelper's GitHub REST calls, with ``requests`` patched out."""

    @staticmethod
    def _helper():
        """Helper authenticated with token ``access_token`` as user ``username``."""
        return GithubHelper('access_token', 'username')

    @staticmethod
    def _headers():
        """A fresh copy of the dummy headers the helper is expected to pass through."""
        return {'header': 'value'}

    @patch('requests.get')
    def test_check_repository_visibility(self, mock_get):
        mock_get.return_value = MagicMock(status_code=200)
        mock_get.return_value.json.return_value = {'private': False}

        self.assertEqual(self._helper().check_repository_visibility('owner', 'repo'), False)
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/owner/repo",
            headers={"Authorization": "Token access_token", "Accept": "application/vnd.github.v3+json"}
        )

    @patch('requests.get')
    def test_get_file_path(self, mock_get):
        self.assertEqual(self._helper().get_file_path('test.txt', 'dir'), 'dir/test.txt')

    @patch('requests.get')
    def test_search_repo(self, mock_get):
        response = MagicMock()
        response.raise_for_status.return_value = None
        response.json.return_value = 'data'
        mock_get.return_value = response

        self.assertEqual(self._helper().search_repo('owner', 'repo', 'test.txt', ''), 'data')
        mock_get.assert_called_once_with(
            'https://api.github.com/repos/owner/repo/contents/test.txt',
            headers={"Authorization": "token access_token", "Content-Type": "application/vnd.github+json"}
        )

    @patch('requests.get')
    @patch('requests.patch')
    def test_sync_branch(self, mock_patch, mock_get):
        mock_get.return_value.json.return_value = {'commit': {'sha': 'sha'}}
        mock_patch.return_value.status_code = 200

        self._helper().sync_branch('owner', 'repo', 'base', 'head', self._headers())

        mock_get.assert_called_once_with(
            'https://api.github.com/repos/owner/repo/branches/base',
            headers=self._headers()
        )
        mock_patch.assert_called_once_with(
            'https://api.github.com/repos/username/repo/git/refs/heads/head',
            json={'sha': 'sha', 'force': True},
            headers=self._headers()
        )

    @patch('requests.get')
    @patch('requests.post')
    def test_create_branch(self, mock_post, mock_get):
        mock_get.return_value.json.return_value = {'object': {'sha': 'sha'}}
        mock_post.return_value.status_code = 201

        status = self._helper().create_branch('repo', 'base', 'head', self._headers())

        self.assertEqual(status, 201)
        mock_get.assert_called_once_with(
            'https://api.github.com/repos/username/repo/git/refs/heads/base',
            headers=self._headers()
        )
        mock_post.assert_called_once_with(
            'https://api.github.com/repos/username/repo/git/refs',
            json={'ref': 'refs/heads/head', 'sha': 'sha'},
            headers=self._headers()
        )

    @patch('requests.post')
    def test_make_fork(self, mock_post):
        mock_post.return_value.status_code = 202
        helper = self._helper()

        with patch.object(GithubHelper, 'sync_branch') as mock_sync:
            status = helper.make_fork('owner', 'repo', 'base', self._headers())

        self.assertEqual(status, 202)
        mock_post.assert_called_once_with(
            'https://api.github.com/repos/owner/repo/forks',
            headers=self._headers()
        )
        mock_sync.assert_called_once_with('owner', 'repo', 'base', 'base', self._headers())

    @patch('requests.delete')
    def test_delete_file(self, mock_delete):
        mock_delete.return_value.status_code = 200
        helper = self._helper()

        with patch.object(GithubHelper, 'get_sha', return_value='sha') as mock_sha:
            status = helper.delete_file('repo', 'test.txt', 'path', 'message', 'head', self._headers())

        self.assertEqual(status, 200)
        mock_sha.assert_called_once_with('username', 'repo', 'test.txt', 'path')
        mock_delete.assert_called_once_with(
            'https://api.github.com/repos/username/repo/contents/path/test.txt',
            json={'message': 'message', 'sha': 'sha', 'branch': 'head'},
            headers=self._headers()
        )

    @patch('requests.post')
    def test_create_pull_request(self, mock_post):
        mock_post.return_value.status_code = 201

        status = self._helper().create_pull_request('owner', 'repo', 'head', 'base', self._headers())

        self.assertEqual(status, 201)
        expected_payload = {
            'title': 'Pull request by username',
            'body': 'Please review and merge this change.',
            'head': 'username:head',
            'head_repo': 'repo',
            'base': 'base'
        }
        mock_post.assert_called_once_with(
            'https://api.github.com/repos/owner/repo/pulls',
            json=expected_payload,
            headers=self._headers()
        )

    @patch('requests.get')
    def test_get_pull_request_content_success(self, mock_get):
        mock_get.return_value.status_code = 200
        mock_get.return_value.text = "some_content"
        self.assertEqual(self._helper().get_pull_request_content("owner", "repo", 1), "some_content")

    @patch('requests.get')
    def test_get_pull_request_content_not_found(self, mock_get):
        mock_get.return_value.status_code = 404
        self.assertIsNone(self._helper().get_pull_request_content("owner", "repo", 1))

    @patch('requests.get')
    def test_get_latest_commit_id_of_pull_request(self, mock_get):
        mock_get.return_value.status_code = 200
        # The expected id is the sha of the last commit in the returned list.
        mock_get.return_value.json.return_value = [{"sha": "123"}, {"sha": "456"}]
        self.assertEqual(
            self._helper().get_latest_commit_id_of_pull_request("owner", "repo", 1), "456"
        )

    @patch('requests.post')
    def test_add_line_comment_to_pull_request(self, mock_post):
        mock_post.return_value.status_code = 201
        mock_post.return_value.json.return_value = {"id": 1, "body": "comment"}
        result = self._helper().add_line_comment_to_pull_request(
            "owner", "repo", 1, "commit_id", "file_path", 1, "comment"
        )
        self.assertEqual(result, {"id": 1, "body": "comment"})

    # TODO: cover the remaining GithubHelper methods.
# Allow running this module directly (`python <file>`) as well as via a test runner.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/unit_tests/helper/test_json_cleaner.py
================================================
from superagi.helper.json_cleaner import JsonCleaner
import pytest
def test_extract_json_section():
    """Only the JSON object embedded in the surrounding text is returned."""
    embedded = 'Before json {"key":"value"} after json'
    assert JsonCleaner.extract_json_section(embedded) == '{"key":"value"}'
def test_remove_escape_sequences():
    """A literal backslash-n in the input comes back as a real newline."""
    raw_text = r'This is a test\nstring'
    assert JsonCleaner.remove_escape_sequences(raw_text) == 'This is a test\nstring'
def test_balance_braces():
    """Missing closing braces are appended until the braces balance."""
    test_str = '{{{{"key":"value"}}'
    result = JsonCleaner.balance_braces(test_str)
    assert result == '{{{{"key":"value"}}}}'


def test_clean_boolean():
    """JSON ``false`` is rewritten to Python's ``False``.

    This test used to be named ``test_balance_braces`` too, so it silently
    replaced the brace-balancing test above and only one of them ever ran.
    """
    test_str = '{"key": false}'
    result = JsonCleaner.clean_boolean(test_str)
    assert result == '{"key": False}'
================================================
FILE: tests/unit_tests/helper/test_resource_helper.py
================================================
from unittest.mock import patch, MagicMock
from superagi.helper.resource_helper import ResourceHelper
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.resource import Resource
def test_make_written_file_resource(mocker):
    """make_written_file_resource builds an OUTPUT Resource for a file written by an agent."""
    # Filesystem calls are stubbed so nothing touches the real disk.
    for target, value in (
        ('os.getcwd', '/'),
        ('os.makedirs', None),
        ('os.path.getsize', 1000),
        ('os.path.splitext', ("", ".txt")),
    ):
        mocker.patch(target, return_value=value)
    # Successive get_config() lookups made by the helper, in call order.
    mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['FILE', '/', '/', 'FILE'])

    agent = Agent(id=1, name='TestAgent')
    execution = AgentExecution(id=1, name='TestExecution')
    session = MagicMock()

    with patch('superagi.helper.resource_helper.logger'):
        # No existing resource row matches.
        session.query.return_value.filter_by.return_value.first.return_value = None
        session.add.return_value = Resource(
            name='test.txt',
            path='/test.txt',
            storage_type='FILE',
            size=1000,
            type='application/txt',
            channel='OUTPUT',
            agent_id=1,
            agent_execution_id=1
        )
        created = ResourceHelper.make_written_file_resource('test.txt', agent, execution, session)

    expected_fields = {
        'name': 'test.txt',
        'path': '/test.txt',
        'storage_type': 'FILE',
        'size': 1000,
        'type': 'application/txt',
        'channel': 'OUTPUT',
        'agent_id': 1,
    }
    for field, value in expected_fields.items():
        assert getattr(created, field) == value
def test_get_resource_path(mocker):
    """With cwd and the resource root both mocked to '/', 'test.txt' resolves to '/test.txt'."""
    mocker.patch('os.getcwd', return_value='/')
    mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/'])
    assert ResourceHelper.get_resource_path('test.txt') == '/test.txt'
def test_get_agent_resource_path(mocker):
    """get_agent_write_resource_path resolves to '/test.txt' under the mocked root and config."""
    mocker.patch('os.getcwd', return_value='/')
    mocker.patch('os.makedirs')
    mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/'])
    agent = Agent(id=1, name='TestAgent')
    execution = AgentExecution(id=1, name='TestExecution')
    write_path = ResourceHelper.get_agent_write_resource_path('test.txt', agent, execution)
    assert write_path == '/test.txt'
def test_get_formatted_agent_level_path():
    """The '{agent_id}' placeholder is replaced by '<agent name>_<agent id>'."""
    template = "/data/{agent_id}/file.txt"
    agent = Agent(id=1, name="TestAgent")
    assert ResourceHelper.get_formatted_agent_level_path(agent, template) == "/data/TestAgent_1/file.txt"
def test_get_formatted_agent_execution_level_path():
    """The '{agent_execution_id}' placeholder is replaced by '<execution name>_<execution id>'."""
    template = "/results/{agent_execution_id}/output.csv"
    execution = AgentExecution(id=1, name="TestExecution")
    formatted = ResourceHelper.get_formatted_agent_execution_level_path(execution, template)
    assert formatted == "/results/TestExecution_1/output.csv"
================================================
FILE: tests/unit_tests/helper/test_s3_helper.py
================================================
import json
import pytest
from unittest.mock import MagicMock, patch
from botocore.exceptions import NoCredentialsError
from fastapi import HTTPException
from superagi.helper.s3_helper import S3Helper
@pytest.fixture()
def s3helper_object():
    """Provide a fresh S3Helper instance to each test."""
    helper = S3Helper()
    return helper
def test__get_s3_client(s3helper_object):
    """The private S3 client factory must read both AWS credential settings from config."""
    with patch('superagi.helper.s3_helper.get_config', return_value='test') as get_config_mock:
        # Name-mangled access to the private S3Helper.__get_s3_client.
        s3helper_object._S3Helper__get_s3_client()
    for setting in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
        get_config_mock.assert_any_call(setting)
@pytest.mark.parametrize('have_creds, raises', [(True, False), (False, True)])
def test_upload_file(s3helper_object, have_creds, raises):
    """upload_file succeeds with credentials and turns NoCredentialsError into an HTTPException."""
    s3helper_object.s3.upload_fileobj = MagicMock(
        side_effect=None if have_creds else NoCredentialsError()
    )
    if not raises:
        s3helper_object.upload_file('file', 'path')
        return
    with pytest.raises(HTTPException):
        s3helper_object.upload_file('file', 'path')
@pytest.mark.parametrize('have_creds, raises', [(True, False), (False, True)])
def test_get_json_file(s3helper_object, have_creds, raises):
    """get_json_file decodes the object's JSON body, or raises HTTPException when credentials are missing."""
    mock_path = "mock_path"
    get_object = MagicMock()
    s3helper_object.s3.get_object = get_object

    # Fake boto3 response: a 'Body' stream whose read() yields UTF-8 JSON bytes.
    body = MagicMock()
    body.read = MagicMock(return_value=bytes(json.dumps("content_of_json"), 'utf-8'))

    if raises:
        # Missing credentials: the get_object call itself fails.
        get_object.side_effect = NoCredentialsError()
        with pytest.raises(HTTPException):
            s3helper_object.get_json_file(mock_path)
        return

    get_object.return_value = {'Body': body}
    content = s3helper_object.get_json_file(mock_path)
    get_object.assert_called_with(Bucket=s3helper_object.bucket_name, Key=mock_path)
    assert content == "content_of_json"
def test_check_file_exists_in_s3(s3helper_object):
    """A listing without 'Contents' means missing; a listing with a 'Contents' key (even empty) means present."""
    scenarios = [({}, False), ({'Contents': []}, True)]
    for listing, exists in scenarios:
        s3helper_object.s3.list_objects_v2 = MagicMock(return_value=listing)
        assert s3helper_object.check_file_exists_in_s3('path') == exists
@pytest.mark.parametrize('http_status, expected_result, raises', [(200, 'file_content', False), (500, None, True)])
def test_read_from_s3(s3helper_object, http_status, expected_result, raises):
    """read_from_s3 returns the UTF-8 decoded body on HTTP 200 and raises on other statuses."""
    # Encode up front. Previously the lambda evaluated bytes(None, 'utf-8') in the
    # error case, so Body.read() itself raised TypeError and the broad
    # pytest.raises(Exception) passed even if the helper ignored the status code.
    body_bytes = b'' if expected_result is None else expected_result.encode('utf-8')
    s3helper_object.s3.get_object = MagicMock(
        return_value={'ResponseMetadata': {'HTTPStatusCode': http_status},
                      'Body': MagicMock(read=lambda: body_bytes)}
    )
    if raises:
        with pytest.raises(Exception):
            s3helper_object.read_from_s3('path')
    else:
        assert s3helper_object.read_from_s3('path') == expected_result
@pytest.mark.parametrize('http_status, expected_result, raises',
                         [(200, b'file_content', False),
                          (500, None, True)])
def test_read_binary_from_s3(s3helper_object, http_status, expected_result, raises):
    """read_binary_from_s3 returns the raw body bytes on HTTP 200 and raises otherwise."""
    fake_response = {
        'ResponseMetadata': {'HTTPStatusCode': http_status},
        'Body': MagicMock(read=lambda: expected_result),
    }
    s3helper_object.s3.get_object = MagicMock(return_value=fake_response)
    if not raises:
        assert s3helper_object.read_binary_from_s3('path') == expected_result
        return
    with pytest.raises(Exception):
        s3helper_object.read_binary_from_s3('path')
def test_delete_file_success(s3helper_object):
    """delete_file completes without raising and issues the S3 delete call."""
    s3helper_object.s3.delete_object = MagicMock()
    try:
        s3helper_object.delete_file('path')
    except Exception as exc:  # not a bare except: don't swallow KeyboardInterrupt/SystemExit
        pytest.fail(f"Unexpected Exception ! {exc!r}")
    # A test that only checks "no exception" would also pass if the delete were skipped.
    s3helper_object.s3.delete_object.assert_called_once()
def test_delete_file_fail(s3helper_object):
    """Any failure from the S3 delete call surfaces as an HTTPException."""
    failing_delete = MagicMock(side_effect=Exception())
    s3helper_object.s3.delete_object = failing_delete
    with pytest.raises(HTTPException):
        s3helper_object.delete_file('path')
def test_list_files_from_s3(s3helper_object):
    """list_files_from_s3 returns the keys from the listing's 'Contents' entries."""
    keys = ['path/to/file1.txt', 'path/to/file2.jpg']
    listing = {'Contents': [{'Key': key} for key in keys]}
    s3helper_object.s3.list_objects_v2 = MagicMock(return_value=listing)
    file_list = s3helper_object.list_files_from_s3('path/to/')
    assert len(file_list) == len(keys)
    for key in keys:
        assert key in file_list
def test_list_files_from_s3_no_contents(s3helper_object):
    """A listing without a 'Contents' key is treated as an error."""
    empty_listing = {}
    s3helper_object.s3.list_objects_v2 = MagicMock(return_value=empty_listing)
    with pytest.raises(Exception):
        s3helper_object.list_files_from_s3('path/to/')
def test_list_files_from_s3_raises_exception(s3helper_object):
    """Errors raised by the S3 listing call propagate out of list_files_from_s3."""
    boom = Exception("An error occurred")
    s3helper_object.s3.list_objects_v2 = MagicMock(side_effect=boom)
    with pytest.raises(Exception):
        s3helper_object.list_files_from_s3('path/to/')
================================================
FILE: tests/unit_tests/helper/test_time_helper.py
================================================
from superagi.helper.time_helper import get_time_difference, parse_interval_to_seconds
import pytest
def test_get_time_difference():
    """A gap of 10h26m (plus seconds) spanning midnight reports only hours and minutes."""
    start = "2023-06-26 17:31:08.884322"
    end = "2023-06-27 03:57:42.038497"
    expected = dict(years=0, months=0, days=0, hours=10, minutes=26)
    assert get_time_difference(start, end) == expected
def test_parse_interval_to_seconds():
    """'<n> <Unit>' strings convert to seconds; a month counts as 30 days."""
    day = 86400
    cases = {
        "2 Minutes": 2 * 60,
        "3 Hours": 3 * 3600,
        "1 Days": 1 * day,
        "7 Weeks": 7 * 7 * day,
        "2 Months": 2 * 30 * day,
    }
    for interval, seconds in cases.items():
        assert parse_interval_to_seconds(interval) == seconds
================================================
FILE: tests/unit_tests/helper/test_token_counter.py
================================================
import pytest
from typing import List
from unittest.mock import MagicMock, patch
from superagi.types.common import BaseMessage
from superagi.helper.token_counter import TokenCounter
from superagi.models.models import Models
@pytest.fixture()
def setup_model_token_limit():
    """Model-name -> token-limit table used to stub Models.fetch_model_tokens."""
    return dict([
        ("gpt-3.5-turbo-0301", 4032),
        ("gpt-4-0314", 8092),
        ("gpt-3.5-turbo", 4032),
        ("gpt-4", 8092),
        ("gpt-3.5-turbo-16k", 16184),
        ("gpt-4-32k", 32768),
        ("gpt-4-32k-0314", 32768),
    ])
@patch.object(Models, "fetch_model_tokens", autospec=True)
def test_token_limit(mock_fetch_model_tokens, setup_model_token_limit):
    """token_limit returns each model's stubbed limit and 8092 for an unknown model."""
    mock_fetch_model_tokens.return_value = setup_model_token_limit
    counter = TokenCounter(MagicMock(), 1)
    limits = setup_model_token_limit
    assert all(counter.token_limit(name) == limit for name, limit in limits.items())
    assert counter.token_limit("non_existing_model") == 8092
def test_count_message_tokens():
    """count_message_tokens yields 26 tokens for this conversation, for a known and an unknown model."""
    message_list = [{'content': 'Hello, How are you doing ?'}, {'content': 'I am good. How about you ?'}]
    # Scope the stub to this test. Assigning a MagicMock directly to
    # BaseMessage.list_from_dicts (as before) replaced it for every later test too.
    with patch.object(BaseMessage, "list_from_dicts", return_value=message_list):
        for model in ("gpt-3.5-turbo-0301", "non_existing_model"):
            token_count = TokenCounter.count_message_tokens(BaseMessage.list_from_dicts(message_list), model)
            assert token_count == 26
def test_count_text_tokens():
    """count_text_tokens returns fixed counts for these sample sentences.

    If the hard-coded values in TokenCounter.count_text_tokens change, update
    the expected counts below accordingly.
    """
    samples = [
        ("You are a helpful assistant.", 10),
        ("What is your name?", 9),
    ]
    for text, tokens in samples:
        assert TokenCounter.count_text_tokens(text) == tokens
================================================
FILE: tests/unit_tests/helper/test_tool_helper.py
================================================
import os
import shutil
import sys
from pathlib import Path
from unittest.mock import patch, Mock
import pytest
from superagi.helper.tool_helper import (
parse_github_url,
load_module_from_file,
extract_repo_name,
get_readme_content_from_code_link, download_tool, handle_tools_import, compare_toolkit, compare_configs,
compare_tools
)
def setup_function():
    """Per-test setup hook: make sure the scratch directory 'target_folder' exists."""
    scratch_dir = 'target_folder'
    os.makedirs(scratch_dir, exist_ok=True)
# Teardown function to remove the directory
def teardown_function():
shutil.rmtree('target_folder')
@pytest.fixture
def mock_requests_get(monkeypatch):
    """Replace requests.get with a stub that serves canned GitHub responses."""
    class MockResponse:
        def __init__(self, content, status_code):
            self.content = content
            self.status_code = status_code
            self.text = None if content is None else content.decode()

    canned_bodies = {
        'https://api.github.com/repos/owner/repo/zipball/main': b'ZIP_CONTENT',
        'https://raw.githubusercontent.com/username/repo/main/README.MD': b'README_CONTENT',
        'https://raw.githubusercontent.com/username/repo/main/README.md': b'README_CONTENT',
    }

    def fake_get(url):
        # Known URLs answer 200 with their canned body; anything else is a 404.
        if url not in canned_bodies:
            return MockResponse(None, 404)
        return MockResponse(canned_bodies[url], 200)

    monkeypatch.setattr('requests.get', fake_get)
def test_parse_github_url():
    """A repo URL becomes 'owner/repo/main'."""
    assert parse_github_url('https://github.com/owner/repo') == 'owner/repo/main'
def test_load_module_from_file(tmp_path):
    """load_module_from_file imports a .py file by path and exposes its functions."""
    # Write into pytest's per-test tmp_path rather than the current working
    # directory, so a failing assertion cannot leave test_module.py behind.
    file_path = tmp_path / 'test_module.py'
    file_path.write_text(
        "def hello():\n"
        "    return 'Hello, world!'\n"
    )
    module = load_module_from_file(file_path)
    assert module.hello() == 'Hello, world!'
def test_get_readme_content_from_code_link(mock_requests_get):
    """The README of the linked repo is returned from the stubbed requests.get."""
    readme = get_readme_content_from_code_link('https://github.com/username/repo')
    assert readme == 'README_CONTENT'
def test_extract_repo_name():
    """The last path segment of a GitHub URL is the repo name."""
    assert extract_repo_name('https://github.com/username/repo') == 'repo'
@patch('requests.get')
@patch('zipfile.ZipFile')
def test_download_tool(mock_zip, mock_get):
    """download_tool fetches the GitHub zipball and opens it from target_folder/tool.zip."""
    mock_get.return_value = Mock(content=b'file content')
    archive = mock_zip.return_value.__enter__.return_value
    archive.namelist.return_value = ['owner-repo/somefile.txt']

    download_tool('https://github.com/owner/repo', 'target_folder')

    mock_get.assert_called_once_with('https://api.github.com/repos/owner/repo/zipball/main')
    mock_zip.assert_called_once_with('target_folder/tool.zip', 'r')
def test_handle_tools_import():
    """handle_tools_import runs against mocked config/listing without removing sys.path entries."""
    # patch.object(sys, 'path', ...) gives the call a private copy of sys.path, so
    # anything it appends does not leak into later tests.
    with patch('superagi.config.config.get_config') as mock_get_config, \
            patch('os.listdir') as mock_listdir, \
            patch('superagi.helper.auth.db'), \
            patch.object(sys, 'path', list(sys.path)):
        mock_get_config.return_value = "superagi/tools"
        # os.listdir returns a list of names; a bare string would be iterated per character.
        mock_listdir.return_value = ["test_tool"]
        initial_path_length = len(sys.path)
        handle_tools_import()
        # NOTE(review): the previous `assert len(sys.path), initial_path_length + 2`
        # was an assert-with-message and therefore always passed. The exact number of
        # appended entries depends on filesystem checks inside handle_tools_import that
        # are not visible here. TODO: pin the exact expected count once confirmed.
        assert len(sys.path) >= initial_path_length
def test_compare_tools():
    """compare_tools is falsy for identical tools and truthy when name or description differs."""
    def tool(name="Tool A", description="This is Tool A"):
        return {"name": name, "description": description}

    assert not compare_tools(tool(), tool())
    assert compare_tools(tool(), tool(name="Tool B"))
    assert compare_tools(tool(), tool(description="This is Tool B"))
def test_compare_configs():
    """compare_configs is falsy for equal configs and truthy when the key differs."""
    same_a, same_b = {"key": "config_key"}, {"key": "config_key"}
    assert not compare_configs(same_a, same_b)
    diff_a, diff_b = {"key": "config_key_1"}, {"key": "config_key_2"}
    assert compare_configs(diff_a, diff_b)
def test_compare_toolkit():
    """compare_toolkit detects differences nested in 'tools' or 'configs', not only top-level fields."""
    def make_toolkit(tool_description="This is Tool A", config_key="config_key"):
        # Fresh dict per call so no two arguments share nested objects.
        return {
            "description": "Toolkit Description",
            "show_toolkit": True,
            "name": "Toolkit",
            "tool_code_link": "https://example.com/toolkit",
            "tools": [{"name": "Tool A", "description": tool_description}],
            "configs": [{"key": config_key}],
        }

    # Identical toolkits: no difference reported.
    assert not compare_toolkit(make_toolkit(), make_toolkit())
    # Only a nested tool description differs.
    assert compare_toolkit(make_toolkit(), make_toolkit(tool_description="This is Tool B"))
    # Only a nested config key differs.
    assert compare_toolkit(make_toolkit(config_key="config_key_1"), make_toolkit(config_key="config_key_2"))
================================================
FILE: tests/unit_tests/helper/test_twitter_helper.py
================================================
import unittest
from unittest.mock import Mock, patch
from requests.models import Response
from requests_oauthlib import OAuth1Session
from superagi.helper.twitter_helper import TwitterHelper
class TestSendTweets(unittest.TestCase):
    """Tests for TwitterHelper.send_tweets with OAuth1Session.post patched out."""

    TWEET_URL = "https://api.twitter.com/2/tweets"

    def _send_with_status(self, mock_post, status_code):
        """Send a canned tweet while OAuth1Session.post returns ``status_code``.

        Verifies the POST target and payload, then returns the helper's response.
        """
        params = {"status": "Hello, Twitter!"}
        canned = Response()
        canned.status_code = status_code
        mock_post.return_value = canned

        response = TwitterHelper().send_tweets(params, Mock())

        # patch.object replaced OAuth1Session.post on the class, so the injected
        # mock records the helper's call. Assert on it directly instead of going
        # through a throwaway OAuth1Session built with a Mock key.
        mock_post.assert_called_once_with(self.TWEET_URL, json=params)
        return response

    @patch.object(OAuth1Session, 'post')
    def test_send_tweets_success(self, mock_post):
        response = self._send_with_status(mock_post, 200)
        self.assertEqual(response.status_code, 200)

    @patch.object(OAuth1Session, 'post')
    def test_send_tweets_failure(self, mock_post):
        response = self._send_with_status(mock_post, 400)
        self.assertEqual(response.status_code, 400)
# Allow running this module directly with `python <file>` in addition to a test runner.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/unit_tests/helper/test_twitter_tokens.py
================================================
import unittest
from unittest.mock import patch, Mock, MagicMock
from typing import NamedTuple
import ast
from sqlalchemy.orm import Session
from superagi.helper.twitter_tokens import Creds, TwitterTokens
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens
import time
import http.client
class TestCreds(unittest.TestCase):
    """Creds exposes each positional constructor argument under the matching attribute."""

    def test_init(self):
        # Each value equals its field name, so the mapping is easy to check.
        fields = ('api_key', 'api_key_secret', 'oauth_token', 'oauth_token_secret')
        creds = Creds(*fields)
        for field in fields:
            self.assertEqual(getattr(creds, field), field)
class TestTwitterTokens(unittest.TestCase):
    """Unit tests for TwitterTokens' OAuth request-token helpers."""

    # No class-level TwitterTokens instance: one used to be built at import time
    # with the Session *class* and was always shadowed by the one created in setUp.

    def setUp(self):
        # Fresh mocked DB session and helper for every test.
        self.mock_session = Mock(spec=Session)
        self.twitter_tokens = TwitterTokens(session=self.mock_session)

    def test_init(self):
        self.assertEqual(self.twitter_tokens.session, self.mock_session)

    def test_percent_encode(self):
        self.assertEqual(self.twitter_tokens.percent_encode("#"), "%23")

    def test_gen_nonce(self):
        self.assertEqual(len(self.twitter_tokens.gen_nonce()), 32)

    @patch.object(time, 'time', return_value=1234567890)
    @patch.object(http.client, 'HTTPSConnection')
    @patch('superagi.helper.twitter_tokens.TwitterTokens.gen_nonce', return_value=123456)
    @patch('superagi.helper.twitter_tokens.TwitterTokens.percent_encode', return_value="encoded")
    def test_get_request_token(self, mock_percent_encode, mock_gen_nonce, mock_https_connection, mock_time):
        # The mocked HTTPS connection returns a URL-encoded token pair, which the
        # helper is expected to parse into a dict.
        response_mock = Mock()
        response_mock.read.return_value = b'oauth_token=test_token&oauth_token_secret=test_secret'
        mock_https_connection.return_value.getresponse.return_value = response_mock
        api_data = {"api_key": "test_key", "api_secret": "test_secret"}
        expected_result = {'oauth_token': 'test_token', 'oauth_token_secret': 'test_secret'}
        self.assertEqual(self.twitter_tokens.get_request_token(api_data), expected_result)
# Allow running this module directly with `python <file>` in addition to a test runner.
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/unit_tests/helper/test_webhooks.py
================================================
import json
from unittest.mock import Mock, patch
import pytest
from superagi.helper.webhook_manager import WebHookManager
from superagi.models.webhooks import Webhooks
@pytest.fixture
def mock_session():
    """Bare Mock handed to WebHookManager as its DB session."""
    session = Mock()
    return session


@pytest.fixture
def mock_agent_execution():
    """Bare Mock fixture; requested by the callback test but not otherwise configured."""
    execution = Mock()
    return execution


@pytest.fixture
def mock_agent():
    """Bare Mock fixture; requested by the callback test but not otherwise configured."""
    agent = Mock()
    return agent


@pytest.fixture
def mock_webhook():
    """Bare Mock fixture; requested by the callback test but not otherwise configured."""
    webhook = Mock()
    return webhook
@pytest.fixture
def mock_org():
    """Organisation stub whose ``id`` is the fixed string 'mock_org_id'."""
    return Mock(id="mock_org_id")
def test_agent_status_change_callback(
    mock_session, mock_agent_execution, mock_agent, mock_org, mock_webhook
):
    """A status change that no webhook filter subscribes to must not trigger an HTTP POST."""
    curr_status = "NEW_STATUS"
    old_status = "OLD_STATUS"

    execution = Mock()
    execution.agent_id = "mock_agent_id"

    agent = Mock()
    agent.get_agent_organisation.return_value = mock_org

    # The only org webhook listens for PAUSED/RUNNING, so NEW_STATUS is not subscribed.
    webhook = Mock(spec=Webhooks)
    webhook.filters = {"status": ["PAUSED", "RUNNING"]}
    mock_session.query.return_value.filter.return_value.all.return_value = [webhook]

    with patch(
        'superagi.controllers.agent_execution_config.AgentExecution.get_agent_execution_from_id',
        return_value=execution,
    ), patch(
        'superagi.models.agent.Agent.get_agent_from_id',
        return_value=agent,
    ), patch(
        'requests.post',
        return_value=Mock(status_code=200),
    ) as mock_post, patch('json.dumps'):
        WebHookManager(mock_session).agent_status_change_callback(execution, curr_status, old_status)

    # The former final check (truthiness of an auto-created Mock attribute) could
    # never fail. Assert the observable outcome instead: no callback was sent.
    mock_post.assert_not_called()
================================================
FILE: tests/unit_tests/jobs/__init__.py
================================================
================================================
FILE: tests/unit_tests/jobs/conftest.py
================================================
# content of conftest.py
def pytest_configure(config):
    """pytest hook: set ``sys._called_from_test = True`` so code can detect a test run."""
    import sys

    setattr(sys, "_called_from_test", True)
def pytest_unconfigure(config):
    """pytest hook: remove the ``sys._called_from_test`` flag at the end of the session."""
    # The import in pytest_configure is function-local, so sys is imported again here.
    import sys

    delattr(sys, "_called_from_test")
================================================
FILE: tests/unit_tests/jobs/test_resource_summary.py
================================================
================================================
FILE: tests/unit_tests/jobs/test_scheduling_executor.py
================================================
import pytest
from unittest.mock import patch, MagicMock, ANY, PropertyMock
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.jobs.scheduling_executor import ScheduledAgentExecutor
from datetime import datetime
@patch('superagi.worker.execute_agent.delay')
@patch('superagi.jobs.scheduling_executor.Session')
@patch('superagi.models.agent.Agent')
@patch('superagi.jobs.scheduling_executor.AgentWorkflow')
@patch('superagi.models.agent_execution.AgentExecution')
def test_execute_scheduled_agent(AgentExecutionMock, AgentWorkflowMock, AgentMock, SessionMock, execute_agent_delay_mock):
    """execute_scheduled_agent commits via the session and enqueues execute_agent with (id, datetime)."""
    # Arrange
    agent_id = 1
    name = 'Test Agent'

    # session setup
    session_mock = MagicMock()
    SessionMock.return_value = session_mock

    # agent setup
    mock_agent = MagicMock(spec=Agent)
    mock_agent.id = agent_id
    session_mock.query.return_value.get.return_value = mock_agent

    db_agent_execution_mock = AgentExecution(status="RUNNING", last_execution_time=datetime.now(), agent_id=agent_id, name=name, num_of_calls=0, num_of_tokens=0, current_agent_step_id=1)
    AgentExecutionMock.return_value = db_agent_execution_mock

    # The mocked session never assigns primary keys, so give every AgentExecution
    # id=123 (presumably including the one the executor builds). The previous
    # `type(obj).id = PropertyMock(...)` overwrote the real AgentExecution.id column
    # attribute for the rest of the test session; patch.object restores it on exit.
    with patch.object(AgentExecution, 'id', new_callable=PropertyMock, return_value=123):
        executor = ScheduledAgentExecutor()

        # Act
        executor.execute_scheduled_agent(agent_id, name)

        # Assert
        assert session_mock.query.called
        assert session_mock.commit.called
        execute_agent_delay_mock.assert_called_once_with(db_agent_execution_mock.id, ANY)
        args, _ = execute_agent_delay_mock.call_args
        assert isinstance(args[0], int)
        assert isinstance(args[1], datetime)
================================================
FILE: tests/unit_tests/llms/__init__.py
================================================
================================================
FILE: tests/unit_tests/llms/test_google_palm.py
================================================
from unittest.mock import patch
from superagi.llms.google_palm import GooglePalm
@patch('superagi.llms.google_palm.palm')
def test_chat_completion(mock_palm):
    """chat_completion forwards the prompt and the instance's sampling settings to palm.generate_text."""
    model = 'models/text-bison-001'
    palm_instance = GooglePalm('test_key', model=model)
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    max_tokens = 100
    generated = mock_palm.generate_text.return_value
    generated.result = 'Sure, I can help with that.'

    result = palm_instance.chat_completion(messages, max_tokens)

    assert result == {"response": generated, "content": 'Sure, I can help with that.'}
    expected_kwargs = dict(
        model=model,
        prompt='You are a helpful assistant.',
        temperature=palm_instance.temperature,
        candidate_count=palm_instance.candidate_count,
        top_k=palm_instance.top_k,
        top_p=palm_instance.top_p,
        max_output_tokens=int(max_tokens),
    )
    mock_palm.generate_text.assert_called_once_with(**expected_kwargs)
def test_verify_access_key():
    """With a fake key, verify_access_key reports the key as invalid."""
    # NOTE(review): nothing is mocked, so this presumably relies on the real PaLM
    # client rejecting the key (and may need network access) — confirm.
    palm_instance = GooglePalm('test_key', model='models/text-bison-001')
    assert palm_instance.verify_access_key() is False
================================================
FILE: tests/unit_tests/llms/test_hugging_face.py
================================================
import os
from unittest.mock import patch, Mock
from unittest import TestCase
import requests
import json
from superagi.llms.hugging_face import HuggingFace
from superagi.config.config import get_config
from superagi.llms.utils.huggingface_utils.tasks import Tasks, TaskParameters
from superagi.llms.utils.huggingface_utils.public_endpoints import ACCOUNT_VERIFICATION_URL
class TestHuggingFace(TestCase):
    """Tests for HuggingFace's access-key and inference-endpoint verification."""

    # Disabled chat_completion test, kept for reference:
    # @patch.object(requests, "post")
    # def test_chat_completion(self, mock_post):
    #     # Arrange
    #     api_key = 'test_api_key'
    #     model = 'test_model'
    #     end_point = 'test_end_point'
    #     hf_instance = HuggingFace(api_key, model=model, end_point=end_point)
    #     messages = [{"role": "system", "content": "You are a helpful assistant."}]
    #     mock_post.return_value = Mock()
    #     mock_post.return_value.content = b'{"0": {"generated_text": "Sure, I can help with that."}}'
    #
    #     # Act
    #     result = hf_instance.chat_completion(messages)
    #
    #     # Assert
    #     mock_post.assert_called_with(
    #         end_point,
    #         headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
    #         data=json.dumps({
    #             "inputs": "You are a helpful assistant.\nThe responses in json schema:",
    #             "parameters": TaskParameters().get_params(Tasks.TEXT_GENERATION),
    #             "options": {
    #                 "use_cache": False,
    #                 "wait_for_model": True,
    #             }
    #         })
    #     )
    #     assert result == {"response": {0: {"generated_text": "Sure, I can help with that."}}, "content": "Sure, I can help with that."}

    @staticmethod
    def _build_client():
        """Return (api_key, end_point, HuggingFace instance) built from fixed test values."""
        api_key = 'test_api_key'
        end_point = 'test_end_point'
        return api_key, end_point, HuggingFace(api_key, model='test_model', end_point=end_point)

    @staticmethod
    def _json_headers(api_key):
        """Headers the wrapper is expected to send with each request."""
        return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    @patch.object(requests, "get")
    def test_verify_access_key(self, mock_get):
        api_key, _, hf_instance = self._build_client()
        mock_get.return_value.status_code = 200

        result = hf_instance.verify_access_key()

        mock_get.assert_called_with(ACCOUNT_VERIFICATION_URL, headers=self._json_headers(api_key))
        assert result is True

    @patch.object(requests, "post")
    def test_verify_end_point(self, mock_post):
        api_key, end_point, hf_instance = self._build_client()
        mock_post.return_value.json.return_value = {"valid_response": "valid"}

        result = hf_instance.verify_end_point()

        mock_post.assert_called_with(
            end_point,
            headers=self._json_headers(api_key),
            data=json.dumps({"inputs": "validating end_point"}),
        )
        assert result == {"valid_response": "valid"}
================================================
FILE: tests/unit_tests/llms/test_model_factory.py
================================================
import pytest
from unittest.mock import Mock
from superagi.llms.google_palm import GooglePalm
from superagi.llms.hugging_face import HuggingFace
from superagi.llms.llm_model_factory import get_model, build_model_with_api_key
from superagi.llms.openai import OpenAi
from superagi.llms.replicate import Replicate
# Fixtures for the mock objects
# Each fixture is defined once. Duplicate definitions of mock_replicate,
# mock_google_palm and mock_hugging_face used to redefine the earlier ones.
@pytest.fixture
def mock_openai():
    """Mock with OpenAi's spec, to be swapped in for the class in the factory module."""
    return Mock(spec=OpenAi)


@pytest.fixture
def mock_replicate():
    """Mock with Replicate's spec, to be swapped in for the class in the factory module."""
    return Mock(spec=Replicate)


@pytest.fixture
def mock_google_palm():
    """Mock with GooglePalm's spec, to be swapped in for the class in the factory module."""
    return Mock(spec=GooglePalm)


@pytest.fixture
def mock_hugging_face():
    """Mock with HuggingFace's spec, to be swapped in for the class in the factory module."""
    return Mock(spec=HuggingFace)
# Test build_model_with_api_key function
def _assert_builds_with_api_key(monkeypatch, class_name, provider, mock_cls):
    """Swap ``class_name`` in the factory module for ``mock_cls``, build ``provider``
    and check that the class was instantiated once with only the API key."""
    monkeypatch.setattr(f'superagi.llms.llm_model_factory.{class_name}', mock_cls)
    model = build_model_with_api_key(provider, 'fake_key')
    mock_cls.assert_called_once_with(api_key='fake_key')
    assert isinstance(model, Mock)


# One test per provider. The OpenAi and Replicate tests were previously defined
# twice, with the later copies silently replacing the earlier ones.
def test_build_model_with_openai(mock_openai, monkeypatch):
    _assert_builds_with_api_key(monkeypatch, 'OpenAi', 'OpenAi', mock_openai)


def test_build_model_with_replicate(mock_replicate, monkeypatch):
    _assert_builds_with_api_key(monkeypatch, 'Replicate', 'Replicate', mock_replicate)


def test_build_model_with_google_palm(mock_google_palm, monkeypatch):
    _assert_builds_with_api_key(monkeypatch, 'GooglePalm', 'Google Palm', mock_google_palm)


def test_build_model_with_hugging_face(mock_hugging_face, monkeypatch):
    _assert_builds_with_api_key(monkeypatch, 'HuggingFace', 'Hugging Face', mock_hugging_face)
def test_build_model_with_unknown_provider(capsys):
    """An unrecognised provider yields None and prints 'Unknown provider.' to stdout."""
    # capsys is pytest's built-in fixture for capturing print output.
    assert build_model_with_api_key('Unknown', 'fake_key') is None
    stdout = capsys.readouterr().out
    assert "Unknown provider." in stdout
================================================
FILE: tests/unit_tests/llms/test_open_ai.py
================================================
import openai
import pytest
from unittest.mock import MagicMock, patch
from superagi.llms.openai import OpenAi, MAX_RETRY_ATTEMPTS
@patch('superagi.llms.openai.openai')
def test_chat_completion(mock_openai):
    """chat_completion forwards messages and sampling settings to ChatCompletion.create."""
    model = 'gpt-4'
    openai_instance = OpenAi('test_key', model=model)
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    max_tokens = 100
    chat_response = MagicMock()
    chat_response.choices[0].message = {"content": "I'm here to help!"}
    create = mock_openai.ChatCompletion.create
    create.return_value = chat_response

    result = openai_instance.chat_completion(messages, max_tokens)

    assert result == {"response": chat_response, "content": "I'm here to help!"}
    create.assert_called_once_with(
        n=openai_instance.number_of_results,
        model=model,
        messages=messages,
        temperature=openai_instance.temperature,
        max_tokens=max_tokens,
        top_p=openai_instance.top_p,
        frequency_penalty=openai_instance.frequency_penalty,
        presence_penalty=openai_instance.presence_penalty,
    )
@patch('superagi.llms.openai.wait_random_exponential.__call__')
@patch('superagi.llms.openai.openai')
def test_chat_completion_retry_rate_limit_error(mock_openai, mock_wait_random_exponential):
# Arrange
model = 'gpt-4'
api_key = 'test_key'
openai_instance = OpenAi(api_key, model=model)
messages = [{"role": "system", "content": "You are a helpful assistant."}]
max_tokens = 100
mock_openai.ChatCompletion.create.side_effect = openai.error.RateLimitError("Rate limit exceeded")
# Mock sleep time
mock_wait_random_exponential.return_value = 0.1
# Act
result = openai_instance.chat_completion(messages, max_tokens)
# Assert
assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Rate limit exceeded"}
assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS
@patch('superagi.llms.openai.wait_random_exponential.__call__')
@patch('superagi.llms.openai.openai')
def test_chat_completion_retry_timeout_error(mock_openai, mock_wait_random_exponential):
    """Repeated openai Timeout errors exhaust the retry budget and come back as
    an ERROR_OPENAI result rather than propagating."""
    llm = OpenAi('test_key', model='gpt-4')
    create = mock_openai.ChatCompletion.create
    create.side_effect = openai.error.Timeout("Timeout occured")
    mock_wait_random_exponential.return_value = 0.1  # keep the back-off short

    outcome = llm.chat_completion([{"role": "system", "content": "You are a helpful assistant."}], 100)

    assert outcome == {"error": "ERROR_OPENAI", "message": "Open ai exception: Timeout occured"}
    assert create.call_count == MAX_RETRY_ATTEMPTS
@patch('superagi.llms.openai.wait_random_exponential.__call__')
@patch('superagi.llms.openai.openai')
def test_chat_completion_retry_try_again_error(mock_openai, mock_wait_random_exponential):
    """openai TryAgain errors are retried up to MAX_RETRY_ATTEMPTS and then
    converted into an ERROR_OPENAI result."""
    llm = OpenAi('test_key', model='gpt-4')
    create = mock_openai.ChatCompletion.create
    create.side_effect = openai.error.TryAgain("Try Again")
    mock_wait_random_exponential.return_value = 0.1  # avoid long sleeps between attempts

    outcome = llm.chat_completion([{"role": "system", "content": "You are a helpful assistant."}], 100)

    assert create.call_count == MAX_RETRY_ATTEMPTS or True  # evaluated again below
    assert outcome == {"error": "ERROR_OPENAI", "message": "Open ai exception: Try Again"}
    assert create.call_count == MAX_RETRY_ATTEMPTS
@patch('superagi.llms.openai.openai')
def test_verify_access_key(mock_openai):
    """verify_access_key reports False when the OpenAI API rejects the key.

    The ``openai`` module used by superagi.llms.openai is patched so the test
    makes no real network request. The previous version called the live API
    with a fake key, so its outcome depended on network availability.
    """
    openai_instance = OpenAi('test_key', model='gpt-4')
    # NOTE(review): assumes verify_access_key probes the key via
    # openai.Model.list() and maps any exception to False; confirm against
    # superagi/llms/openai.py.
    mock_openai.Model.list.side_effect = openai.error.AuthenticationError("Incorrect API key provided")

    result = openai_instance.verify_access_key()

    assert result is False
================================================
FILE: tests/unit_tests/llms/test_replicate.py
================================================
import os
from unittest.mock import patch
import pytest
import requests
from unittest import TestCase
from superagi.llms.replicate import Replicate
from superagi.config.config import get_config
class TestReplicate(TestCase):
    """Unit tests for the Replicate LLM wrapper with replicate.run and HTTP mocked."""

    @patch('os.environ')
    @patch('replicate.run')
    def test_chat_completion(self, mock_replicate_run, mock_os_environ):
        # Pass every generation knob explicitly so the full constructor path runs.
        generation_kwargs = dict(
            max_length=1000,
            temperature=0.7,
            candidate_count=1,
            top_k=40,
            top_p=0.95,
        )
        llm = Replicate('test_api_key', model='test_model', version='test_version', **generation_kwargs)
        prompt = [{"role": "system", "content": "You are a helpful assistant."}]
        # A one-item iterator stands in for replicate.run's output
        # (presumably a stream of text chunks — confirm in the wrapper).
        mock_replicate_run.return_value = iter(['Sure, I can help with that.'])

        outcome = llm.chat_completion(prompt)

        assert outcome == {"response": ['Sure, I can help with that.'], "content": 'Sure, I can help with that.'}

    @patch.object(requests, "get")
    def test_verify_access_key(self, mock_get):
        key = 'test_api_key'
        llm = Replicate(key, model='test_model', version='test_version')
        mock_get.return_value.status_code = 200

        assert llm.verify_access_key() is True
        mock_get.assert_called_with("https://api.replicate.com/v1/collections", headers={"Authorization": "Token " + key})

    @patch.object(requests, "get")
    def test_verify_access_key_false(self, mock_get):
        llm = Replicate('test_api_key', model='test_model', version='test_version')
        mock_get.return_value.status_code = 400

        assert llm.verify_access_key() is False
================================================
FILE: tests/unit_tests/models/__init__.py
================================================
================================================
FILE: tests/unit_tests/models/test_agent.py
================================================
from unittest.mock import create_autospec
from sqlalchemy.orm import Session
from superagi.models.agent import Agent
from unittest.mock import patch
def test_get_agent_from_id():
    """get_agent_from_id returns whatever the filtered session query yields."""
    agent_id = 1
    session = create_autospec(Session)
    expected = Agent(id=agent_id, name="Test Agent", project_id=1, description="Agent for testing")
    session.query.return_value.filter.return_value.first.return_value = expected

    assert Agent.get_agent_from_id(session, agent_id) == expected
def test_get_active_agent_by_id():
    """get_active_agent_by_id returns the (non-deleted) agent the filtered query yields."""
    # Mock session whose query(...).filter(...).first() returns our agent.
    session = create_autospec(Session)
    agent_id = 1
    mock_agent = Agent(id=agent_id, name="Test Agent", project_id=1, description="Agent for testing", is_deleted=False)
    session.query.return_value.filter.return_value.first.return_value = mock_agent

    agent = Agent.get_active_agent_by_id(session, agent_id)

    assert agent == mock_agent
    # Compare to the singleton with `is` (PEP 8 / flake8 E712) rather than `==`.
    assert agent.is_deleted is False
def test_eval_tools_key():
    """A "tools" value given as a list-literal string is parsed into a Python list."""
    parsed = Agent.eval_agent_config("tools", "[1, 2, 3]")
    assert parsed == [1, 2, 3]
================================================
FILE: tests/unit_tests/models/test_agent_execution.py
================================================
from datetime import datetime
from unittest.mock import create_autospec, patch, Mock
import pytest
from pytest_mock import mocker
from sqlalchemy.orm import Session
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
from superagi.models.workflows.iteration_workflow import IterationWorkflow
def test_get_agent_execution_from_id():
    """get_agent_execution_from_id returns the execution the filtered query yields."""
    execution_id = 1
    expected_execution = AgentExecution(id=execution_id, name="Test Execution")
    session = create_autospec(Session)
    session.query.return_value.filter.return_value.first.return_value = expected_execution

    found = AgentExecution.get_agent_execution_from_id(session, execution_id)

    assert found == expected_execution
@pytest.fixture
def mock_session(mocker):
    """Session double for the AgentExecution tests in this module.

    Returns a MagicMock (not a plain Mock) so chained query calls such as
    query().filter().first() and any magic-method use are all supported.
    """
    return mocker.MagicMock()
def test_update_tokens(mock_session):
    """update_tokens bumps the call counter by one and adds the new token count."""
    execution = AgentExecution(
        id=1, status='RUNNING', name='Agent', agent_id=1,
        last_execution_time=datetime.now(), num_of_calls=1,
        num_of_tokens=100, current_agent_step_id=1,
    )
    mock_session.query.return_value.filter.return_value.first.return_value = execution

    AgentExecution.update_tokens(mock_session, 1, 50)

    assert (execution.num_of_calls, execution.num_of_tokens) == (2, 150)
def test_assign_next_step_id(mock_session, mocker):
    """assign_next_step_id moves the execution to the requested step; here the
    step is an ITERATION_WORKFLOW one, so the workflow's trigger step id is
    recorded as well."""
    execution = AgentExecution(
        id=1, status='RUNNING', name='Agent', agent_id=1,
        last_execution_time=datetime.now(), num_of_calls=1,
        num_of_tokens=100, current_agent_step_id=1,
    )
    next_step = AgentWorkflowStep(id=2, action_type='ITERATION_WORKFLOW', action_reference_id=1)
    trigger_step = IterationWorkflow(id=3)

    mock_session.query.return_value.filter.return_value.first.return_value = execution
    mocker.patch.object(AgentWorkflowStep, 'find_by_id', return_value=next_step)
    mocker.patch.object(IterationWorkflow, 'fetch_trigger_step_id', return_value=trigger_step)

    AgentExecution.assign_next_step_id(mock_session, 1, 2)

    assert (execution.current_agent_step_id, execution.iteration_workflow_step_id) == (2, 3)
def test_get_execution_by_agent_id_and_status():
    """get_execution_by_agent_id_and_status returns the execution the filtered
    query yields for the given agent id and status."""
    session = create_autospec(Session)
    # The lookup key is the *agent* id, so name it that way.
    agent_id = 1
    mock_agent_execution = AgentExecution(id=1, agent_id=agent_id, name="Test Execution", status="RUNNING")
    session.query.return_value.filter.return_value.first.return_value = mock_agent_execution

    agent_execution = AgentExecution.get_execution_by_agent_id_and_status(session, agent_id, "RUNNING")

    assert agent_execution == mock_agent_execution
    assert agent_execution.status == "RUNNING"
# The mock_session fixture for this module is defined once, near the top of the
# file; a second definition here used to silently shadow it.
================================================
FILE: tests/unit_tests/models/test_agent_execution_config.py
================================================
import unittest
from unittest.mock import MagicMock, patch, call
from distlib.util import AND
from superagi.models.agent_execution_config import AgentExecutionConfiguration
class TestAgentExecutionConfiguration(unittest.TestCase):
    """Tests for turning stored execution-configuration rows into Python values."""

    def setUp(self):
        self.session = MagicMock()
        self.execution = MagicMock()
        self.execution.id = 1

    def test_fetch_configuration(self):
        rows = [
            MagicMock(key="goal", value="['test_goal']"),
            MagicMock(key="instruction", value="['test_instruction']"),
            MagicMock(key="tools", value="[1]"),
        ]
        self.session.query.return_value.filter_by.return_value.all.return_value = rows

        config = AgentExecutionConfiguration.fetch_configuration(self.session, self.execution)

        self.assertDictEqual(config, {"goal": ["test_goal"], "instruction": ["test_instruction"], "tools": [1]})

    def test_eval_agent_config(self):
        parsed = AgentExecutionConfiguration.eval_agent_config("goal", "['test_goal']")
        self.assertEqual(parsed, ["test_goal"])
================================================
FILE: tests/unit_tests/models/test_agent_execution_feed.py
================================================
import pytest
from unittest.mock import Mock, create_autospec
from sqlalchemy.orm import Session
from superagi.models.agent_execution_feed import AgentExecutionFeed
def test_get_last_tool_response():
    """Without a tool name, get_last_tool_response returns the feed of the first
    row handed back by the (mocked) ordered query."""
    mock_session = create_autospec(Session)
    agent_execution_feed_1 = AgentExecutionFeed(id=1, agent_execution_id=2, feed="Tool test1", role='system')
    agent_execution_feed_2 = AgentExecutionFeed(id=2, agent_execution_id=2, feed="Tool test2", role='system')
    # created_at is never set on these rows; the mock just returns them in this
    # order, so "latest" here means "first row the query returns" (presumably the
    # real query orders newest-first — verify in AgentExecutionFeed).
    mock_session.query().filter().order_by().all.return_value = [agent_execution_feed_1, agent_execution_feed_2]

    result = AgentExecutionFeed.get_last_tool_response(mock_session, 2)

    assert result == agent_execution_feed_1.feed
def test_get_last_tool_response_with_tool_name():
    """With a tool name, the matching row's feed is returned even if it is not first."""
    feeds = [
        AgentExecutionFeed(id=1, agent_execution_id=2, feed="Tool test1", role='system'),
        AgentExecutionFeed(id=2, agent_execution_id=2, feed="Tool test2", role='system'),
    ]
    session = create_autospec(Session)
    session.query().filter().order_by().all.return_value = feeds

    assert AgentExecutionFeed.get_last_tool_response(session, 2, "test2") == feeds[1].feed
================================================
FILE: tests/unit_tests/models/test_agent_schedule.py
================================================
from unittest.mock import create_autospec
from sqlalchemy.orm import Session
from superagi.models.agent_schedule import AgentSchedule
def test_find_by_agent_id():
    """find_by_agent_id returns the schedule the filtered session query yields."""
    agent_id = 1
    schedule = AgentSchedule(
        id=1,
        agent_id=agent_id,
        start_time="2023-08-10 12:17:00",
        recurrence_interval="2 Minutes",
        expiry_runs=2,
    )
    session = create_autospec(Session)
    session.query.return_value.filter.return_value.first.return_value = schedule

    assert AgentSchedule.find_by_agent_id(session, agent_id) == schedule
================================================
FILE: tests/unit_tests/models/test_agent_template.py
================================================
from unittest.mock import Mock, patch
import requests
from superagi.models.agent_template import AgentTemplate
from superagi.models.workflows.agent_workflow import AgentWorkflow
def test_to_dict():
    """to_dict exposes id, name and description."""
    fields = {'id': 1, 'name': 'test', 'description': 'desc'}
    assert AgentTemplate(**fields).to_dict() == fields
def test_to_json():
    """to_json emits a JSON document whose fields mirror the template's attributes.

    The output is parsed rather than compared as a raw string, so the test does
    not depend on key order or json.dumps whitespace defaults.
    """
    import json

    agent_template = AgentTemplate(id=1, name='test', description='desc')
    result = agent_template.to_json()
    assert isinstance(result, str)
    assert json.loads(result) == {'id': 1, 'name': 'test', 'description': 'desc'}
def test_from_json():
    """from_json rebuilds a template with the id, name and description from the JSON."""
    template = AgentTemplate.from_json('{"id": 1, "name": "test", "description": "desc"}')
    assert (template.id, template.name, template.description) == (1, 'test', 'desc')
def test_main_keys():
    """main_keys is a list that includes at least 'goal' and 'instruction'."""
    keys = AgentTemplate.main_keys()
    assert isinstance(keys, list)
    for required in ('goal', 'instruction'):
        assert required in keys
@patch.object(requests, 'get')
def test_fetch_marketplace_list(mock_get):
    """fetch_marketplace_list returns the decoded list from a 200 marketplace response."""
    payload = [{'id': 1, 'name': 'test', 'description': 'desc'}]
    mock_get.return_value = Mock(status_code=200, json=lambda: payload)

    listing = AgentTemplate.fetch_marketplace_list('test', 1)

    assert len(listing) == 1
    assert listing[0]['id'] == 1
@patch.object(requests, 'get')
def test_fetch_marketplace_detail(mock_get):
    """fetch_marketplace_detail returns the decoded template dict from a 200 response."""
    mock_get.return_value = Mock(status_code=200, json=lambda: {'id': 1, 'name': 'test', 'description': 'desc'})

    detail = AgentTemplate.fetch_marketplace_detail(1)

    assert (detail['id'], detail['name'], detail['description']) == (1, 'test', 'desc')
def test_eval_agent_config():
    """eval_agent_config keeps plain strings, casts project_id, and parses goal lists."""
    cases = [
        ('name', 'test', 'test'),
        ('project_id', '1', 1),
        ('goal', '["goal1", "goal2"]', ["goal1", "goal2"]),
    ]
    for key, raw, expected in cases:
        assert AgentTemplate.eval_agent_config(key, raw) == expected
@patch('superagi.models.agent_template.AgentTemplate.fetch_marketplace_detail')
@patch('sqlalchemy.orm.Session')
def test_clone_agent_template_from_marketplace(mock_session, mock_fetch_marketplace_detail):
    """Cloning a marketplace template yields a new AgentTemplate carrying the
    marketplace name/description and the requested organisation id."""
    marketplace_payload = {
        "id": 1,
        "name": "test",
        "description": "desc",
        "agent_workflow_name": "workflow1",
        "configs": {
            "config1": {"value": "value1"},
            "config2": {"value": "value2"},
        },
    }
    mock_fetch_marketplace_detail.return_value = marketplace_payload
    existing_workflow = AgentWorkflow(id=1, name='workflow1')
    mock_session.query.return_value.filter.return_value.first.return_value = existing_workflow

    cloned = AgentTemplate.clone_agent_template_from_marketplace(mock_session, 1, 1)

    assert isinstance(cloned, AgentTemplate)
    assert (cloned.organisation_id, cloned.name, cloned.description) == (1, 'test', 'desc')
================================================
FILE: tests/unit_tests/models/test_agent_workflow.py
================================================
import pytest
from unittest.mock import MagicMock
from sqlalchemy.orm import Session
from superagi.models.workflows.agent_workflow import AgentWorkflow
@pytest.fixture
def mock_session():
    """Session double whose query(...).filter(...).first() yields an AgentWorkflow-spec mock."""
    session = MagicMock(spec=Session)
    workflow_stub = MagicMock(spec=AgentWorkflow)
    session.query.return_value.filter.return_value.first.return_value = workflow_stub
    return session
def test_find_by_name(mock_session):
    """find_by_name queries the AgentWorkflow table once and returns the hit."""
    found = AgentWorkflow.find_by_name(mock_session, 'workflow_name')

    mock_session.query.assert_called_once_with(AgentWorkflow)
    assert found.__class__ == AgentWorkflow
def test_find_or_create_by_name_new(mock_session):
    """When no workflow matches, a new AgentWorkflow is created and added exactly once."""
    # Simulate "not found" so the create branch runs.
    mock_session.query.return_value.filter.return_value.first.return_value = None

    created = AgentWorkflow.find_or_create_by_name(mock_session, 'workflow_name', 'description')

    mock_session.add.assert_called_once()
    assert created.__class__ == AgentWorkflow
def test_find_or_create_by_name_exists(mock_session):
    """An existing workflow is returned as-is and nothing new is added to the session."""
    existing = AgentWorkflow.find_or_create_by_name(mock_session, 'workflow_name', 'description')

    mock_session.add.assert_not_called()
    assert existing.__class__ == AgentWorkflow
def test_fetch_trigger_step_id(mock_session):
    """fetch_trigger_step_id issues a single query and returns a non-None result."""
    trigger = AgentWorkflow.fetch_trigger_step_id(mock_session, 1)

    mock_session.query.assert_called_once()
    assert trigger is not None
================================================
FILE: tests/unit_tests/models/test_agent_workflow_step.py
================================================
import pytest
from unittest.mock import MagicMock, patch, Mock
from sqlalchemy.orm import Session
import json
from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep
@patch('sqlalchemy.orm.Session.query')
def test_find_by_id(mock_query):
    """find_by_id returns the step produced by the patched Session.query chain."""
    step_stub = MagicMock(spec=AgentWorkflowStep)
    mock_query.return_value.filter.return_value.first.return_value = step_stub

    found = AgentWorkflowStep.find_by_id(Session(), 1)

    assert isinstance(found, AgentWorkflowStep)
@patch('sqlalchemy.orm.Session.query')
def test_find_by_unique_id(mock_query):
    """find_by_unique_id returns the step produced by the patched Session.query chain."""
    step_stub = MagicMock(spec=AgentWorkflowStep)
    mock_query.return_value.filter.return_value.first.return_value = step_stub

    found = AgentWorkflowStep.find_by_unique_id(Session(), '1')

    assert isinstance(found, AgentWorkflowStep)
def test_from_json():
    """from_json builds an AgentWorkflowStep from its JSON-serialised fields."""
    serialized = json.dumps({
        'id': 1,
        'agent_workflow_id': 1,
        'unique_id': '1',
        'step_type': 'TRIGGER',
        'action_type': 'TOOL',
        'action_reference_id': 1,
        'next_steps': [],
    })

    assert isinstance(AgentWorkflowStep.from_json(serialized), AgentWorkflowStep)
def test_to_dict():
    """to_dict returns a dict carrying every column the step was built with."""
    expected = {
        'id': 1,
        'agent_workflow_id': 1,
        'unique_id': '1',
        'step_type': 'TRIGGER',
        'action_type': 'TOOL',
        'action_reference_id': 1,
        'next_steps': [],
    }
    step = AgentWorkflowStep(
        id=1,
        agent_workflow_id=1,
        unique_id='1',
        step_type='TRIGGER',
        action_type='TOOL',
        action_reference_id=1,
        next_steps=[],
    )

    as_dict = step.to_dict()

    assert isinstance(as_dict, dict)
    for column, value in expected.items():
        assert as_dict[column] == value
@patch('sqlalchemy.orm.Session.add')
@patch('sqlalchemy.orm.Session.commit')
@patch('sqlalchemy.orm.Session.query')
@patch('superagi.models.workflows.agent_workflow_step.AgentWorkflowStepTool.find_or_create_tool')
def test_find_or_create_tool_workflow_step(mock_find_or_create_tool, mock_query, mock_commit, mock_add):
    """When no step exists yet, a new TOOL step is built that points at the tool
    returned by find_or_create_tool and has no next steps."""
    mock_find_or_create_tool.return_value = MagicMock(id=2)
    # Returning None from first() simulates "workflow step not found".
    mock_query.return_value.filter.return_value.first.return_value = None
    session = MagicMock(spec=Session)
    session.query = mock_query

    step = AgentWorkflowStep.find_or_create_tool_workflow_step(
        session=session,
        agent_workflow_id=1,
        unique_id='1',
        tool_name='test_tool',
        input_instruction='test_instruction',
    )

    assert isinstance(step, AgentWorkflowStep)
    assert (step.agent_workflow_id, step.unique_id) == (1, '1')
    assert (step.action_type, step.action_reference_id) == ('TOOL', 2)
    assert step.next_steps == []
@patch('sqlalchemy.orm.Session.commit')
@patch('sqlalchemy.orm.Session.query')
@patch('superagi.models.workflows.agent_workflow_step.AgentWorkflowStepTool.find_or_create_tool')
def test_find_or_create_tool_workflow_step_exists(mock_find_or_create_tool, mock_query, mock_commit):
    """An already-stored tool step is returned unchanged."""
    mock_find_or_create_tool.return_value = MagicMock(id=2)
    stored_step = MagicMock(spec=AgentWorkflowStep)
    mock_query.return_value.filter.return_value.first.return_value = stored_step
    session = MagicMock(spec=Session)
    session.query = mock_query

    returned = AgentWorkflowStep.find_or_create_tool_workflow_step(
        session=session,
        agent_workflow_id=1,
        unique_id='1',
        tool_name='test_tool',
        input_instruction='test_instruction',
    )

    assert returned == stored_step
@patch('sqlalchemy.orm.Session.commit')
@patch('sqlalchemy.orm.Session.query')
@patch('superagi.models.workflows.iteration_workflow.IterationWorkflow.find_workflow_by_name')
def test_find_or_create_iteration_workflow_step_exists(mock_find_workflow_by_name, mock_query, mock_commit):
    """An already-stored iteration-workflow step is returned unchanged."""
    mock_find_workflow_by_name.return_value = MagicMock(id=2)
    stored_step = MagicMock(spec=AgentWorkflowStep)
    mock_query.return_value.filter.return_value.first.return_value = stored_step
    session = MagicMock(spec=Session)
    session.query = mock_query

    returned = AgentWorkflowStep.find_or_create_iteration_workflow_step(
        session=session,
        agent_workflow_id=1,
        unique_id='1',
        iteration_workflow_name='test_iteration_workflow',
        step_type='NORMAL',
    )

    assert returned == stored_step
@patch('sqlalchemy.orm.Session.commit')
@patch('sqlalchemy.orm.Session.query')
@patch('superagi.models.workflows.agent_workflow_step.AgentWorkflowStep.find_by_id')
def test_add_next_workflow_step(mock_find_by_id, mock_query, mock_commit):
    """Linking a step with no successors appends one {step_response, step_id} entry."""
    mock_find_by_id.return_value = MagicMock(spec=AgentWorkflowStep, unique_id='2')
    origin_step = MagicMock(spec=AgentWorkflowStep, next_steps=[])
    mock_query.return_value.filter.return_value.first.return_value = origin_step
    session = MagicMock(spec=Session)
    session.query = mock_query

    updated = AgentWorkflowStep.add_next_workflow_step(
        session=session,
        current_agent_step_id=1,
        next_step_id=2,
        step_response='test_response',
    )

    assert updated == origin_step
    assert len(updated.next_steps) == 1
    link = updated.next_steps[0]
    assert link['step_response'] == 'test_response'
    assert link['step_id'] == '2'
@patch('sqlalchemy.orm.Session.commit')
@patch('sqlalchemy.orm.Session.query')
@patch('superagi.models.workflows.agent_workflow_step.AgentWorkflowStep.find_by_id')
def test_add_next_workflow_step_existing(mock_find_by_id, mock_query, mock_commit):
    """Re-linking to an already-linked step overwrites its response instead of duplicating it."""
    mock_find_by_id.return_value = MagicMock(spec=AgentWorkflowStep, unique_id='2')
    prior_links = [{"step_response": 'previous_response', "step_id": '2'}]
    origin_step = MagicMock(spec=AgentWorkflowStep, next_steps=prior_links)
    mock_query.return_value.filter.return_value.first.return_value = origin_step
    session = MagicMock(spec=Session)
    session.query = mock_query

    updated = AgentWorkflowStep.add_next_workflow_step(
        session=session,
        current_agent_step_id=1,
        next_step_id=2,
        step_response='test_response',
    )

    assert updated == origin_step
    assert len(updated.next_steps) == 1
    link = updated.next_steps[0]
    assert link['step_response'] == 'test_response'
    assert link['step_id'] == '2'
@patch('superagi.models.workflows.agent_workflow_step.AgentWorkflowStep.find_by_id')
@patch('superagi.models.workflows.agent_workflow_step.AgentWorkflowStep.find_by_unique_id')
def test_fetch_default_next_step(mock_find_by_unique_id, mock_find_by_id):
    """The successor linked under the 'default' response is resolved and returned."""
    mock_find_by_id.return_value = MagicMock(
        spec=AgentWorkflowStep, next_steps=[{"step_response": 'default', "step_id": '2'}]
    )
    successor = MagicMock(spec=AgentWorkflowStep, unique_id='2')
    mock_find_by_unique_id.return_value = successor

    chosen = AgentWorkflowStep.fetch_default_next_step(
        session=MagicMock(spec=Session),
        current_agent_step_id=1,
    )

    assert chosen == successor
@patch('superagi.models.workflows.agent_workflow_step.AgentWorkflowStep.find_by_id')
def test_fetch_default_next_step_none(mock_find_by_id):
    """With no 'default' link among the successors, None is returned."""
    mock_find_by_id.return_value = MagicMock(
        spec=AgentWorkflowStep, next_steps=[{"step_response": 'non-default', "step_id": '2'}]
    )

    chosen = AgentWorkflowStep.fetch_default_next_step(
        session=MagicMock(spec=Session),
        current_agent_step_id=1,
    )

    assert chosen is None
================================================
FILE: tests/unit_tests/models/test_agent_workflow_step_tool.py
================================================
from unittest.mock import patch, MagicMock
from sqlalchemy.orm import Session
from superagi.models.workflows.agent_workflow_step_tool import AgentWorkflowStepTool
@patch('sqlalchemy.orm.Session.query')
def test_find_by_id(mock_query):
    """find_by_id returns the tool row produced by the patched Session.query chain."""
    tool_stub = MagicMock(spec=AgentWorkflowStepTool)
    mock_query.return_value.filter.return_value.first.return_value = tool_stub

    assert isinstance(AgentWorkflowStepTool.find_by_id(Session(), 1), AgentWorkflowStepTool)
@patch('sqlalchemy.orm.Session.add')
@patch('sqlalchemy.orm.Session.query')
def test_find_or_create_tool_new(mock_query, mock_add):
    """When no matching tool row exists, a fresh AgentWorkflowStepTool is returned."""
    # first() -> None simulates "tool not stored yet".
    mock_query.return_value.filter_by.return_value.first.return_value = None
    session = MagicMock(spec=Session)
    session.query = mock_query

    created = AgentWorkflowStepTool.find_or_create_tool(
        session=session,
        step_unique_id='test_step',
        tool_name='test_tool',
        input_instruction='test_input',
        output_instruction='test_output',
        history_enabled=False,
        completion_prompt='test_prompt',
    )

    assert created.__class__ == AgentWorkflowStepTool
@patch('sqlalchemy.orm.Session.query')
def test_find_or_create_tool_exists(mock_query):
    """An already-stored tool row is returned unchanged."""
    stored_tool = MagicMock(spec=AgentWorkflowStepTool)
    mock_query.return_value.filter_by.return_value.first.return_value = stored_tool
    session = MagicMock(spec=Session)
    session.query = mock_query

    returned = AgentWorkflowStepTool.find_or_create_tool(
        session=session,
        step_unique_id='test_step',
        tool_name='test_tool',
        input_instruction='test_input',
        output_instruction='test_output',
        history_enabled=False,
        completion_prompt='test_prompt',
    )

    assert returned == stored_tool
================================================
FILE: tests/unit_tests/models/test_api_key.py
================================================
from unittest.mock import create_autospec
from sqlalchemy.orm import Session
from superagi.models.api_key import ApiKey
def test_get_by_org_id():
    """get_by_org_id returns every key the filtered session query yields."""
    org_id = 1
    keys = [
        ApiKey(id=key_id, org_id=org_id, key=f"key{key_id}", is_expired=False)
        for key_id in (1, 2)
    ]
    session = create_autospec(Session)
    session.query.return_value.filter.return_value.all.return_value = keys

    assert ApiKey.get_by_org_id(session, org_id) == keys
def test_get_by_id():
    """get_by_id returns the key the filtered session query yields."""
    key_id = 1
    stored_key = ApiKey(id=key_id, org_id=1, key="key1", is_expired=False)
    session = create_autospec(Session)
    session.query.return_value.filter.return_value.first.return_value = stored_key

    assert ApiKey.get_by_id(session, key_id) == stored_key
def test_delete_by_id():
    """delete_by_id soft-deletes by flagging the key as expired, then commits and flushes."""
    session = create_autospec(Session)
    api_key_id = 1
    mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False)
    session.query.return_value.filter.return_value.first.return_value = mock_api_key

    ApiKey.delete_by_id(session, api_key_id)

    # Compare to the singleton with `is` (PEP 8 / flake8 E712) rather than `==`.
    assert mock_api_key.is_expired is True
    session.commit.assert_called_once()
    session.flush.assert_called_once()
def test_edit_by_id():
    """update_api_key renames the key, then commits and flushes the session."""
    key_id, renamed_to = 1, "New Name"
    stored_key = ApiKey(id=key_id, org_id=1, key="key1", is_expired=False)
    session = create_autospec(Session)
    session.query.return_value.filter.return_value.first.return_value = stored_key

    ApiKey.update_api_key(session, key_id, renamed_to)

    assert stored_key.name == renamed_to
    for persisted_call in (session.commit, session.flush):
        persisted_call.assert_called_once()
================================================
FILE: tests/unit_tests/models/test_call_logs.py
================================================
import pytest
from unittest.mock import MagicMock
from superagi.models.call_logs import CallLogs
@pytest.fixture
def mock_session():
    """MagicMock session whose query(...).filter(...).first() finds nothing."""
    stub = MagicMock()
    stub.query.return_value.filter.return_value.first.return_value = None
    return stub
@pytest.mark.parametrize("agent_execution_name, agent_id, tokens_consumed, tool_used, model, org_id",
                         [("example_execution", 1, 1, "Test Tool", "Test Model", 1)])
def test_create_call_logs(mock_session, agent_execution_name, agent_id, tokens_consumed, tool_used, model, org_id):
    """A CallLogs row keeps every constructor field and can be handed to session.add.

    The add/assert pair alone only checks that the mock recorded its own call,
    so the field checks below are what actually exercise the model.
    """
    call_log = CallLogs(agent_execution_name=agent_execution_name,
                        agent_id=agent_id,
                        tokens_consumed=tokens_consumed,
                        tool_used=tool_used,
                        model=model,
                        org_id=org_id)

    mock_session.add(call_log)

    mock_session.add.assert_called_once_with(call_log)
    assert call_log.agent_execution_name == agent_execution_name
    assert call_log.agent_id == agent_id
    assert call_log.tokens_consumed == tokens_consumed
    assert call_log.tool_used == tool_used
    assert call_log.model == model
    assert call_log.org_id == org_id
@pytest.mark.parametrize("agent_execution_name, agent_id, tokens_consumed, tool_used, model, org_id",
                         [("example_execution", 1, 1, "Test Tool", "Test Model", 1)])
def test_repr_method_call_logs(mock_session, agent_execution_name, agent_id, tokens_consumed, tool_used, model, org_id):
    """repr(CallLogs) lists id (None before persisting) followed by every field."""
    fields = dict(agent_execution_name=agent_execution_name,
                  agent_id=agent_id,
                  tokens_consumed=tokens_consumed,
                  tool_used=tool_used,
                  model=model,
                  org_id=org_id)
    rendered = repr(CallLogs(**fields))

    expected = ("CallLogs(id=None, agent_execution_name={agent_execution_name}, "
                "agent_id={agent_id}, tokens_consumed={tokens_consumed}, "
                "tool_used={tool_used}, model={model}, org_id={org_id})").format(**fields)
    assert rendered == expected
================================================
FILE: tests/unit_tests/models/test_configuration.py
================================================
import pytest
from fastapi import HTTPException
from unittest.mock import MagicMock
from sqlalchemy.orm import Session
from superagi.models.configuration import Configuration
from superagi.models.agent import Agent
from superagi.models.organisation import Organisation
from superagi.models.project import Project
def test_fetch_configuration():
    """fetch_configuration looks up (organisation_id, key) and returns the stored value."""
    session = MagicMock(spec=Session)
    query = session.query.return_value
    query.filter_by.return_value.first.return_value = Configuration(value="test_value")

    value = Configuration.fetch_configuration(session, 1, "test_key")

    assert value == "test_value"
    session.query.assert_called_once_with(Configuration)
    query.filter_by.assert_called_once_with(organisation_id=1, key="test_key")
def test_fetch_value_by_agent_id():
    """The value is resolved via agent -> project -> organisation -> configuration,
    one query per hop."""
    session = MagicMock(spec=Session)
    lookup_chain = [
        Agent(project_id=1),
        Project(organisation_id=1),
        Organisation(id=1),
        Configuration(value="test_value"),
    ]
    session.query.return_value.filter.return_value.first.side_effect = lookup_chain

    value = Configuration.fetch_value_by_agent_id(session, 1, "test_key")

    assert value == "test_value"
    assert session.query.call_count == 4
def test_fetch_value_by_agent_id_agent_not_found():
    """An unknown agent id surfaces as HTTP 404 "Agent not found"."""
    session = MagicMock(spec=Session)
    session.query.return_value.filter.return_value.first.return_value = None

    with pytest.raises(HTTPException) as raised:
        Configuration.fetch_value_by_agent_id(session, 1, "test_key")

    error = raised.value
    assert error.status_code == 404
    assert error.detail == "Agent not found"
================================================
FILE: tests/unit_tests/models/test_events.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from superagi.models.events import Event
@pytest.fixture
def mock_session():
    """Provide a fresh, unconstrained MagicMock standing in for a DB session."""
    session = MagicMock()
    return session
def test_create_event(mock_session):
    """An Event keeps its constructor arguments and can be handed to a session."""
    event_name = "example_event"
    event_value = 100

    event = Event(event_name=event_name, event_value=event_value)
    mock_session.add(event)

    # Assert on the model itself; asserting only on the mock call made just above
    # would pass regardless of what Event does.
    assert event.event_name == event_name
    assert event.event_value == event_value
    mock_session.add.assert_called_once_with(event)
def test_repr_method_event(mock_session):
    """__repr__ lists id, event_name, event_value, agent_id and org_id; unset columns render as None."""
    event_name = "example_event"
    event_value = 100

    event = Event(event_name=event_name, event_value=event_value)

    expected = (
        f"Event(id=None, event_name={event_name}, "
        f"event_value={event_value}, "
        f"agent_id=None, "
        f"org_id=None)"
    )
    assert repr(event) == expected
================================================
FILE: tests/unit_tests/models/test_iteration_workflow.py
================================================
from unittest.mock import MagicMock, patch
from sqlalchemy.orm import Session
from superagi.models.workflows.iteration_workflow import IterationWorkflow
@patch('sqlalchemy.orm.Session.query')
def test_find_by_id(mock_query):
    """find_by_id queries IterationWorkflow once and returns the first match."""
    found = MagicMock(spec=IterationWorkflow)
    mock_query.return_value.filter.return_value.first.return_value = found
    session = MagicMock(spec=Session)
    session.query = mock_query

    workflow = IterationWorkflow.find_by_id(session, 1)

    mock_query.assert_called_once_with(IterationWorkflow)
    assert workflow.__class__ == IterationWorkflow
@patch('sqlalchemy.orm.Session.query')
def test_find_workflow_by_name(mock_query):
    """find_workflow_by_name queries IterationWorkflow once and returns the first match."""
    session = MagicMock(spec=Session)
    session.query = mock_query
    mock_query.return_value.filter.return_value.first.return_value = MagicMock(spec=IterationWorkflow)

    workflow = IterationWorkflow.find_workflow_by_name(session, 'workflow_name')

    mock_query.assert_called_once_with(IterationWorkflow)
    assert workflow.__class__ == IterationWorkflow
@patch('sqlalchemy.orm.Session.add')
@patch('sqlalchemy.orm.Session.query')
def test_find_or_create_by_name_new(mock_query, mock_add):
    """When no workflow matches the name, an IterationWorkflow instance is returned."""
    session = MagicMock(spec=Session)
    session.query = mock_query
    mock_query.return_value.filter.return_value.first.return_value = None

    created = IterationWorkflow.find_or_create_by_name(session, 'workflow_name', 'description', False)

    assert created.__class__ == IterationWorkflow
@patch('sqlalchemy.orm.Session.add')
@patch('sqlalchemy.orm.Session.query')
def test_find_or_create_by_name_exists(mock_query, mock_add):
    """An existing workflow is returned and nothing is added to the session."""
    mock_query.return_value.filter.return_value.first.return_value = MagicMock(spec=IterationWorkflow)
    session = MagicMock(spec=Session)
    session.query = mock_query

    result = IterationWorkflow.find_or_create_by_name(session, 'workflow_name', 'description', False)

    # On a MagicMock(spec=Session), ``session.add`` is the mock's own child attribute,
    # not the class-level ``Session.add`` patch. Asserting only on ``mock_add`` could
    # therefore never fail, so check the attribute the code under test actually calls.
    session.add.assert_not_called()
    mock_add.assert_not_called()
    assert result.__class__ == IterationWorkflow
@patch('sqlalchemy.orm.Session.query')
def test_fetch_trigger_step_id(mock_query):
    """fetch_trigger_step_id issues exactly one query and yields a non-None result."""
    # Any MagicMock stands in for a real trigger step row.
    trigger_step = MagicMock()
    mock_query.return_value.filter.return_value.first.return_value = trigger_step
    session = MagicMock(spec=Session)
    session.query = mock_query

    step_id = IterationWorkflow.fetch_trigger_step_id(session, 1)

    mock_query.assert_called_once()
    assert step_id is not None
================================================
FILE: tests/unit_tests/models/test_iteration_workflow_step.py
================================================
import pytest
from unittest.mock import MagicMock
from sqlalchemy.orm import Session
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
@pytest.fixture
def mock_session():
    """Session mock whose query(...).filter(...).first() returns an IterationWorkflowStep-spec'd mock."""
    session = MagicMock(spec=Session)
    existing_step = MagicMock(spec=IterationWorkflowStep)
    session.query.return_value.filter.return_value.first.return_value = existing_step
    return session
def test_find_by_id(mock_session):
    """find_by_id queries IterationWorkflowStep once and returns the fixture's step."""
    step = IterationWorkflowStep.find_by_id(mock_session, 1)

    mock_session.query.assert_called_once_with(IterationWorkflowStep)
    assert step.__class__ == IterationWorkflowStep
def test_find_or_create_step_new(mock_session):
    """With no existing step, a new step is created and added to the session once."""
    # Override the fixture so the lookup misses.
    mock_session.query.return_value.filter.return_value.first.return_value = None
    step_args = ('unique_id', 'prompt', 'variables', 'step_type', 'output_type')

    step = IterationWorkflowStep.find_or_create_step(mock_session, 1, *step_args)

    mock_session.add.assert_called_once()
    assert step.__class__ == IterationWorkflowStep
def test_find_or_create_step_exists(mock_session):
    """The fixture's existing step is returned without adding anything to the session."""
    step_args = ('unique_id', 'prompt', 'variables', 'step_type', 'output_type')

    step = IterationWorkflowStep.find_or_create_step(mock_session, 1, *step_args)

    mock_session.add.assert_not_called()
    assert step.__class__ == IterationWorkflowStep
================================================
FILE: tests/unit_tests/models/test_knowledge_configs.py
================================================
import unittest
from unittest.mock import Mock, patch, MagicMock
from sqlalchemy.orm.session import Session
from superagi.models.knowledge_configs import KnowledgeConfigs
class TestKnowledgeConfigs(unittest.TestCase):
    """Unit tests for KnowledgeConfigs marketplace lookup and DB helpers."""

    def setUp(self):
        self.session = Mock(spec=Session)
        self.knowledge_id = 1
        self.test_configs = {'key1': 'value1', 'key2': 'value2'}

    @patch('requests.get')
    def test_fetch_knowledge_config_details_marketplace(self, mock_get):
        # The marketplace answers with key/value records; the helper is expected to
        # return them as a plain {key: value} dict.
        records = [{'key': k, 'value': v} for k, v in self.test_configs.items()]
        mock_get.return_value = Mock(status_code=200)
        mock_get.return_value.json.return_value = records

        configs = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(self.knowledge_id)

        self.assertEqual(configs, self.test_configs)

    def test_add_update_knowledge_config(self):
        KnowledgeConfigs.add_update_knowledge_config(self.session, self.knowledge_id, self.test_configs)

        self.session.add.assert_called()
        self.session.commit.assert_called()

    def test_get_knowledge_config_from_knowledge_id(self):
        row = Mock(key="key1", value="value1")
        self.session.query.return_value.filter.return_value.all.return_value = [row]

        configs = KnowledgeConfigs.get_knowledge_config_from_knowledge_id(self.session, self.knowledge_id)

        self.assertEqual(configs, {"key1": "value1"})

    def test_delete_knowledge_config(self):
        KnowledgeConfigs.delete_knowledge_config(self.session, self.knowledge_id)

        self.session.query.assert_called()
        self.session.commit.assert_called()

    def tearDown(self):
        """No per-test cleanup is needed; mocks are rebuilt in setUp."""
if __name__ == '__main__':
    # Run this module's TestCase classes when the file is executed directly.
    unittest.main()
================================================
FILE: tests/unit_tests/models/test_marketplace_stats.py
================================================
import unittest
from unittest.mock import patch, MagicMock
from sqlalchemy.orm import Session
from superagi.models.marketplace_stats import MarketPlaceStats
class TestMarketPlaceStats(unittest.TestCase):
    """Unit tests for MarketPlaceStats marketplace lookups and local install counters."""

    @patch('requests.get')
    def test_get_knowledge_installation_number(self, mock_get):
        payload = {'download_count': 123}
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        mock_get.return_value = response

        self.assertEqual(MarketPlaceStats.get_knowledge_installation_number(1), payload)

    @patch('requests.get')
    def test_get_knowledge_installation_number_status_not_200(self, mock_get):
        # A 404 response is expected to yield an empty list.
        mock_get.return_value = MagicMock(status_code=404)

        self.assertEqual(MarketPlaceStats.get_knowledge_installation_number(1), [])

    @patch('sqlalchemy.orm.Session')
    def test_update_knowledge_install_number_existing(self, mock_session):
        existing_stat = MagicMock()
        existing_stat.value = '5'
        query = MagicMock()
        query.filter.return_value.first.return_value = existing_stat
        mock_session.query.return_value = query

        MarketPlaceStats.update_knowledge_install_number(mock_session, 1, 10)

        # The stored count is overwritten with the new number, stored as a string.
        self.assertEqual(existing_stat.value, "10")
        query.filter.assert_called()
        mock_session.commit.assert_called()
if __name__ == "__main__":
    # Run this module's TestCase classes when the file is executed directly.
    unittest.main()
================================================
FILE: tests/unit_tests/models/test_models.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from superagi.models.models import Models
@pytest.fixture
def mock_session():
    """Provide a fresh, unconstrained MagicMock standing in for a DB session."""
    session = MagicMock()
    return session
def test_create_model(mock_session):
    """A Models instance keeps its constructor arguments and can be handed to a session."""
    model_name = "example_model"
    end_point = "example_end_point"
    model_provider_id = 1
    token_limit = 500
    model_type = "example_type"
    version = "v1.0"
    org_id = 1
    model_features = "example_model_feature"

    model = Models(model_name=model_name, end_point=end_point,
                   model_provider_id=model_provider_id, token_limit=token_limit,
                   type=model_type, version=version, org_id=org_id, model_features=model_features)
    mock_session.add(model)

    # Assert on the model itself; asserting only on the mock call made just above
    # would pass regardless of what Models does.
    assert model.model_name == model_name
    assert model.end_point == end_point
    assert model.model_provider_id == model_provider_id
    assert model.token_limit == token_limit
    assert model.type == model_type
    assert model.version == version
    assert model.org_id == org_id
    assert model.model_features == model_features
    mock_session.add.assert_called_once_with(model)
def test_repr_method_models(mock_session):
    """__repr__ renders every configured column; id is None before the row is persisted."""
    fields = dict(
        model_name="example_model",
        end_point="example_end_point",
        model_provider_id=1,
        token_limit=500,
        type="example_type",
        version="v1.0",
        org_id=1,
        model_features="example_model_feature",
    )
    mock_session.query.return_value.filter_by.return_value.first.return_value = None

    model = Models(**fields)

    expected = (
        f"Models(id=None, model_name={fields['model_name']}, "
        f"end_point={fields['end_point']}, model_provider_id={fields['model_provider_id']}, "
        f"token_limit={fields['token_limit']}, "
        f"type={fields['type']}, "
        f"version={fields['version']}, "
        f"org_id={fields['org_id']}, "
        f"model_features={fields['model_features']})"
    )
    assert repr(model) == expected
@patch('requests.get')
def test_fetch_marketplace_list(mock_get):
    """fetch_marketplace_list returns the JSON body of a 200 response."""
    mock_get.return_value = MagicMock(status_code=200)
    mock_get.return_value.json.return_value = ['model1', 'model2']

    listing = Models.fetch_marketplace_list(1)

    assert listing == ['model1', 'model2']
# @patch('superagi.models.models_config.ModelsConfig')
# @patch('logging.error')
# def test_get_model_install_details(mock_logging_error, mock_models_config, mock_session):
# mock_model = MagicMock()
# mock_model.model_name = 'model1'
# mock_model.model_provider_id = 1
#
# mock_marketplace_models = [{'model_name': 'model1', 'model_provider_id': 1}, {'model_name': 'model2', 'model_provider_id': 2}]
# mock_session.query.return_value.filter.return_value.all.return_value = [mock_model]
# mock_session.query.return_value.group_by.return_value.all.return_value = [('model1', 1)]
# mock_config = MagicMock()
# mock_config.provider = 'provider1'
#
# def determine_provider(*args):
# for arg in args:
# # Check if mock_config can be returned
# if isinstance(arg, int) and arg == 1:
# return mock_config
# # Return None for all other situations
# return None
#
# mock_session.query.return_value.filter.return_value.first.side_effect = determine_provider
#
# # Call the method
# result = Models.get_model_install_details(mock_session, mock_marketplace_models, MagicMock())
#
# # Verify the result
# expected_result = [
# {"model_name": "model1", "is_installed": True, "installs": 1, "provider": "provider1", "model_provider_id": 1},
# {"model_name": "model2", "is_installed": False, "installs": 0, "provider": None, "model_provider_id": 2}
# ]
# assert result == expected_result
# # Assert that logging.error has been called once when provider is None
# mock_logging_error.assert_called_once()
def test_fetch_model_tokens(mock_session):
    """(model_name, token_limit) rows are returned as a {model_name: token_limit} mapping."""
    rows = [('model1', 500)]
    mock_session.query.return_value.filter.return_value.all.return_value = rows

    tokens = Models.fetch_model_tokens(mock_session, 1)

    assert tokens == {'model1': 500}
def test_store_model_details_when_model_exists(mock_session):
    """A duplicate model name is rejected with an error and nothing is persisted."""
    # Any non-None result from the lookup means the name is already taken.
    mock_session.query.return_value.filter.return_value.first.return_value = MagicMock()
    mock_session.add = MagicMock()

    response = Models.store_model_details(
        mock_session,
        organisation_id=1,
        model_name="example_model",
        description="description",
        end_point="end_point",
        model_provider_id=1,
        token_limit=500,
        type="type",
        version="v1.0",
        context_length=4096
    )

    assert response == {"error": "Model Name already exists"}
    # The rejection must happen before anything is written.
    mock_session.add.assert_not_called()
    mock_session.commit.assert_not_called()
def test_store_model_details_when_model_not_exists(mock_session, monkeypatch):
    """A new model name is persisted (one add, one commit) and a success payload is returned."""
    mock_session.query.return_value.filter.return_value.first.return_value = None
    mock_session.add = MagicMock()
    mock_session.commit = MagicMock()
    mock_fetch_model_by_id = MagicMock(return_value={"provider": "some_provider"})
    # Stub the provider lookup on the class so no real ModelsConfig query runs.
    monkeypatch.setattr('superagi.models.models_config.ModelsConfig.fetch_model_by_id', mock_fetch_model_by_id)

    response = Models.store_model_details(
        mock_session,
        organisation_id=1,
        model_name="example_model",
        description="description",
        end_point="end_point",
        model_provider_id=1,
        token_limit=500,
        type="type",
        version="v1.0",
        context_length=4096
    )

    # model_id is presumably None because a mocked session never assigns a primary key.
    assert response == {"success": "Model Details stored successfully", "model_id": None}
    mock_session.add.assert_called_once()
    mock_session.commit.assert_called_once()
def test_store_model_details_when_unexpected_error_occurs(mock_session, monkeypatch):
    """An exception raised by session.add is reported as a generic error response."""
    mock_session.query.return_value.filter.return_value.first.return_value = None
    mock_session.add = MagicMock(side_effect=Exception("Unknown error"))
    fake_fetch = MagicMock(return_value={"provider": "some_provider"})
    monkeypatch.setattr('superagi.models.models_config.ModelsConfig.fetch_model_by_id', fake_fetch)
    details = dict(
        organisation_id=1,
        model_name="example_model",
        description="description",
        end_point="end_point",
        model_provider_id=1,
        token_limit=500,
        type="type",
        version="v1.0",
        context_length=4096,
    )

    response = Models.store_model_details(mock_session, **details)

    assert response == {"error": "Unexpected Error Occured"}
@patch('superagi.models.models_config.ModelsConfig')
def test_fetch_models(mock_models_config, mock_session):
    """Joined (id, name, description, provider) rows become one dict per model."""
    row = (1, "example_model", "description", "example_provider")
    joined = mock_session.query.return_value.join.return_value
    joined.filter.return_value.all.return_value = [row]

    models = Models.fetch_models(mock_session, 1)

    expected = {
        "id": 1,
        "name": "example_model",
        "description": "description",
        "model_provider": "example_provider",
    }
    assert models == [expected]
@patch('superagi.models.models_config.ModelsConfig')
def test_fetch_model_details(mock_models_config, mock_session):
    """A single joined row is returned as a dict of the model's details."""
    row = (1, "example_model", "description", "end_point", 100, "type1", "example_provider")
    joined = mock_session.query.return_value.join.return_value
    joined.filter.return_value.first.return_value = row

    details = Models.fetch_model_details(mock_session, 1, 1)

    expected = dict(
        id=1,
        name="example_model",
        description="description",
        end_point="end_point",
        token_limit=100,
        type="type1",
        model_provider="example_provider",
    )
    assert details == expected
================================================
FILE: tests/unit_tests/models/test_models_config.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from superagi.models.models_config import ModelsConfig
@pytest.fixture
def mock_session():
    """Provide a fresh, unconstrained MagicMock standing in for a DB session."""
    session = MagicMock()
    return session
def test_create_models_config(mock_session):
    """A ModelsConfig keeps its constructor arguments and can be handed to a session."""
    provider = "example_provider"
    api_key = "example_api_key"
    org_id = 1

    model_config = ModelsConfig(provider=provider, api_key=api_key, org_id=org_id)
    mock_session.add(model_config)

    # Assert on the model itself; asserting only on the mock call made just above
    # would pass regardless of what ModelsConfig does.
    assert model_config.provider == provider
    assert model_config.api_key == api_key
    assert model_config.org_id == org_id
    mock_session.add.assert_called_once_with(model_config)
def test_repr_method_models_config(mock_session):
    """__repr__ shows id, provider and org_id, and does not include the API key."""
    provider, api_key, org_id = "example_provider", "example_api_key", 1
    mock_session.query.return_value.filter_by.return_value.first.return_value = None

    model_config = ModelsConfig(provider=provider, api_key=api_key, org_id=org_id)

    expected = f"ModelsConfig(id=None, provider={provider}, org_id={org_id})"
    assert repr(model_config) == expected
# @patch('superagi.helper.encyption_helper.decrypt_data', return_value='decrypted_api_key')
# @patch('superagi.helper.encyption_helper.encrypt_data', return_value='encrypted_api_key')
# def test_store_api_key(mock_encrypt_data, mock_decrypt_data, mock_session):
# # Arrange
# organisation_id = 1
# model_provider = "example_provider"
# model_api_key = "example_api_key"
#
# # Mock existing entry
# mock_existing_entry = MagicMock()
# mock_session.query.return_value.filter.return_value.first.return_value = mock_existing_entry
# # Call the method
# response = ModelsConfig.store_api_key(mock_session, organisation_id, model_provider, model_api_key)
#
# # Assert
# mock_existing_entry.api_key = 'encrypted_api_key'
# mock_session.add.assert_called_once_with(mock_existing_entry)
# mock_session.commit.assert_called_once()
# assert response == {'message': 'The API key was successfully stored'}
#
# # Mock new entry
# mock_session.query.return_value.filter.return_value.first.return_value = None
# # Call the method
# response = ModelsConfig.store_api_key(mock_session, organisation_id, model_provider, model_api_key)
#
# # Assert
# # The new_entry is local to the store_api_key method, we cannot directly assert its properties.
# # But we can check if a new entry is added.
# mock_session.add.assert_called()
# mock_session.commit.assert_called()
# assert response == {'message': 'The API key was successfully stored'}
# @patch('superagi.helper.encyption_helper.decrypt_data', return_value='decrypted_api_key')
# def test_fetch_api_keys(mock_decrypt_data, mock_session):
# # Arrange
# organisation_id = 1
# # Mock api_key_info
# mock_session.query.return_value.filter.return_value.all.return_value = [("example_provider", "encrypted_api_key")]
#
# # Call the method
# api_keys = ModelsConfig.fetch_api_keys(mock_session, organisation_id)
#
# # Assert
# assert api_keys == [{"provider": "example_provider", "api_key": "decrypted_api_key"}]
#
# @patch('superagi.helper.encyption_helper.decrypt_data', return_value='decrypted_api_key')
# def test_fetch_api_key(mock_session):
# # Arrange
# organisation_id = 1
# model_provider = "example_provider"
# # Mock api_key_data
# mock_api_key_data = MagicMock()
# mock_api_key_data.id = 1
# mock_api_key_data.provider = "provider"
# mock_api_key_data.api_key = "encrypted_api_key"
# mock_session.query.return_value.filter.return_value.first.return_value = mock_api_key_data
#
# # Call the method
# api_key = ModelsConfig.fetch_api_key(mock_session, organisation_id, model_provider)
#
# # Assert
# assert api_key == [{'id': 1, 'provider': "provider", 'api_key': "encrypted_api_key"}]
def test_fetch_model_by_id(mock_session):
    """fetch_model_by_id returns the matched row's provider wrapped in a dict."""
    organisation_id, model_provider_id = 1, 1
    provider_row = MagicMock()
    provider_row.provider = 'some_provider'
    mock_session.query.return_value.filter.return_value.first.return_value = provider_row

    result = ModelsConfig.fetch_model_by_id(mock_session, organisation_id, model_provider_id)

    assert result == {"provider": "some_provider"}
def test_fetch_model_by_id_marketplace(mock_session):
    """fetch_model_by_id_marketplace returns the matched row's provider wrapped in a dict."""
    model_provider_id = 1
    provider_row = MagicMock()
    provider_row.provider = 'some_provider'
    mock_session.query.return_value.filter.return_value.first.return_value = provider_row

    result = ModelsConfig.fetch_model_by_id_marketplace(mock_session, model_provider_id)

    assert result == {"provider": "some_provider"}
================================================
FILE: tests/unit_tests/models/test_project.py
================================================
from unittest.mock import create_autospec
from sqlalchemy.orm import Session
from superagi.models.project import Project
def test_find_by_org_id():
    """find_by_org_id returns the first row of the filtered Project query."""
    session = create_autospec(Session)
    org_id = 123
    expected_project = Project(id=1, name="Test Project", organisation_id=org_id,
                               description="Project for testing")
    session.query.return_value.filter.return_value.first.return_value = expected_project

    found = Project.find_by_org_id(session, org_id)

    assert found == expected_project
def test_find_by_id():
    """find_by_id returns the first row of the filtered Project query."""
    session = create_autospec(Session)
    project_id = 123
    expected_project = Project(id=project_id, name="Test Project", organisation_id=1,
                               description="Project for testing")
    session.query.return_value.filter.return_value.first.return_value = expected_project

    found = Project.find_by_id(session, project_id)

    assert found == expected_project
================================================
FILE: tests/unit_tests/models/test_tool.py
================================================
import pytest
from unittest.mock import MagicMock, call
from sqlalchemy.orm.exc import NoResultFound
from superagi.models.tool import Tool
from superagi.controllers.types.agent_with_config import AgentConfigInput
from fastapi import HTTPException
from typing import List
@pytest.fixture
def mock_session():
    """Session mock whose query(...).get(...) finds the first tool and misses the second."""
    session = MagicMock()
    # SQLAlchemy's Query.get returns None for a missing primary key (it does not raise
    # NoResultFound), so a miss is modelled as None to match what callers actually see.
    session.query.return_value.get = MagicMock(side_effect=[MagicMock(), None])
    return session
def test_get_invalid_tools(mock_session):
    """Ids whose lookup returns None are reported as invalid; found ids are not."""
    tool_lookup = mock_session.query.return_value.get
    # First id is found, second is missing.
    tool_lookup.side_effect = [MagicMock(), None]

    invalid_tool_ids = Tool.get_invalid_tools([1, 2], mock_session)

    assert invalid_tool_ids == [2]
    # Check the queried model and the looked-up ids separately. A chained
    # ``call(Tool).get(n)`` expectation compares equal to ``call().get(n)``,
    # so it never verifies that Tool was the model being queried.
    mock_session.query.assert_any_call(Tool)
    tool_lookup.assert_has_calls([call(1), call(2)])
================================================
FILE: tests/unit_tests/models/test_tool_config.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
@pytest.fixture
def mock_session():
    """Provide a fresh, unconstrained MagicMock standing in for a DB session."""
    session = MagicMock()
    return session
def test_add_or_update_existing_tool_config(mock_session):
    """An existing config row has its value overwritten and the session is committed once."""
    toolkit_id, key, value = 1, "example_key", "example_value"
    stored_config = ToolConfig(toolkit_id=toolkit_id, key=key, value="old_value")
    mock_session.query.return_value.filter_by.return_value.first.return_value = stored_config

    ToolConfig.add_or_update(mock_session, toolkit_id, key, value)

    assert stored_config.value == value
    mock_session.commit.assert_called_once()
def test_add_or_update_new_tool_config(mock_session):
    """With no existing row, a new ToolConfig for (toolkit_id, key) is added and committed."""
    toolkit_id = 1
    key = "example_key"
    value = "example_value"
    mock_session.query.return_value.filter_by.return_value.first.return_value = None

    ToolConfig.add_or_update(mock_session, toolkit_id, key, value)

    # Comparing against a freshly built ToolConfig cannot work: model instances use
    # default (identity) equality. Inspect the object that was passed to add() instead.
    mock_session.add.assert_called_once()
    added = mock_session.add.call_args[0][0]
    assert isinstance(added, ToolConfig)
    assert added.toolkit_id == toolkit_id
    assert added.key == key
    mock_session.commit.assert_called_once()
================================================
FILE: tests/unit_tests/models/test_toolkit.py
================================================
from unittest.mock import MagicMock, patch, call,create_autospec,Mock
import pytest
from superagi.models.organisation import Organisation
from superagi.models.toolkit import Toolkit
from superagi.models.tool import Tool
from sqlalchemy.orm import Session
# Mocked tool
@pytest.fixture
def mock_tool():
    """A Tool-spec'd mock whose id is 1."""
    tool = MagicMock(spec=Tool)
    tool.id = 1
    return tool
# Mocked session
@pytest.fixture
def mock_session(mock_tool):
    """Session mock whose query(...).filter(...) yields [mock_tool] for all() and mock_tool for first().

    This is the only ``mock_session`` fixture in the module; an earlier, argument-less
    definition was always shadowed by this one and has been removed.
    """
    session = MagicMock()
    query = session.query
    query.return_value.filter.return_value.all.return_value = [mock_tool]
    query.return_value.filter.return_value.first.return_value = mock_tool
    return session
# Base URL the Toolkit marketplace helpers are expected to call in the tests below.
# For a local marketplace server, use "http://localhost:8001" instead.
marketplace_url = "https://app.superagi.com/api"
def test_add_or_update_existing_toolkit(mock_session):
    """An existing toolkit is updated in place (no add), then committed and flushed once."""
    name = "example_toolkit"
    organisation_id = 1
    new_values = {
        "description": "Example toolkit description",
        "show_toolkit": True,
        "tool_code_link": "https://example.com/toolkit",
    }
    stale_toolkit = Toolkit(
        name=name,
        description="Old description",
        show_toolkit=False,
        organisation_id=organisation_id,
        tool_code_link="https://old-link.com",
    )
    mock_session.query.return_value.filter.return_value.first.return_value = stale_toolkit

    updated = Toolkit.add_or_update(
        mock_session,
        name,
        new_values["description"],
        new_values["show_toolkit"],
        organisation_id,
        new_values["tool_code_link"],
    )

    assert updated == stale_toolkit
    assert updated.name == name
    assert updated.organisation_id == organisation_id
    for attr, expected in new_values.items():
        assert getattr(updated, attr) == expected
    mock_session.add.assert_not_called()
    mock_session.commit.assert_called_once()
    mock_session.flush.assert_called_once()
def test_add_or_update_new_toolkit(mock_session):
    """With no matching toolkit, a new Toolkit is built, added, committed and flushed once."""
    expected = {
        "name": "example_toolkit",
        "description": "Example toolkit description",
        "show_toolkit": True,
        "organisation_id": 1,
        "tool_code_link": "https://example.com/toolkit",
    }
    mock_session.query.return_value.filter.return_value.first.return_value = None

    created = Toolkit.add_or_update(
        mock_session,
        expected["name"],
        expected["description"],
        expected["show_toolkit"],
        expected["organisation_id"],
        expected["tool_code_link"],
    )

    assert isinstance(created, Toolkit)
    for attr, value in expected.items():
        assert getattr(created, attr) == value
    mock_session.add.assert_called_once_with(created)
    mock_session.commit.assert_called_once()
    mock_session.flush.assert_called_once()
def test_fetch_marketplace_list_success():
    """A 200 response's JSON body is returned for the requested marketplace page."""
    page = 1
    expected_response = [
        {"id": 1, "name": "ToolKit 1", "description": "Description 1"},
        {"id": 2, "name": "ToolKit 2", "description": "Description 2"},
    ]
    expected_url = f"{marketplace_url}/toolkits/marketplace/list/{page}"

    with patch('requests.get') as mock_get:
        mock_get.return_value.status_code = 200
        mock_get.return_value.json.return_value = expected_response

        result = Toolkit.fetch_marketplace_list(page)

        assert result == expected_response
        mock_get.assert_called_once_with(
            expected_url,
            headers={'Content-Type': 'application/json'},
            timeout=10,
        )
def test_fetch_marketplace_detail_success():
    """A 200 response's JSON body is returned; spaces in the path segments are %20-encoded."""
    search_str = "search string"
    toolkit_name = "tool kit name"
    expected_response = {"id": 1, "name": "ToolKit 1", "description": "Description 1"}
    expected_url = "{}/toolkits/marketplace/{}/{}".format(
        marketplace_url,
        search_str.replace(' ', '%20'),
        toolkit_name.replace(' ', '%20'),
    )

    with patch('requests.get') as mock_get:
        mock_get.return_value.status_code = 200
        mock_get.return_value.json.return_value = expected_response

        result = Toolkit.fetch_marketplace_detail(search_str, toolkit_name)

        assert result == expected_response
        mock_get.assert_called_once_with(
            expected_url,
            headers={'Content-Type': 'application/json'},
            timeout=10,
        )
def test_fetch_marketplace_detail_error():
    """A non-200 response (here 500) makes fetch_marketplace_detail return None."""
    search_str = "search string"
    toolkit_name = "tool kit name"
    expected_url = "{}/toolkits/marketplace/{}/{}".format(
        marketplace_url,
        search_str.replace(' ', '%20'),
        toolkit_name.replace(' ', '%20'),
    )

    with patch('requests.get') as mock_get:
        mock_get.return_value.status_code = 500

        result = Toolkit.fetch_marketplace_detail(search_str, toolkit_name)

        assert result is None
        mock_get.assert_called_once_with(
            expected_url,
            headers={'Content-Type': 'application/json'},
            timeout=10,
        )
def test_get_toolkit_from_name_existing_toolkit(mock_session):
    """A known toolkit name is resolved with one filter_by(name, organisation_id).first() lookup."""
    toolkit_name = "example_toolkit"
    organisation = Organisation(id=1)
    expected_toolkit = Toolkit(name=toolkit_name, organisation_id=organisation.id)
    filter_by = mock_session.query.return_value.filter_by
    filter_by.return_value.first.return_value = expected_toolkit

    result = Toolkit.get_toolkit_from_name(mock_session, toolkit_name, organisation)

    assert result == expected_toolkit
    mock_session.query.assert_called_once_with(Toolkit)
    filter_by.assert_called_once_with(name=toolkit_name, organisation_id=organisation.id)
    filter_by.return_value.first.assert_called_once()
def test_get_toolkit_from_name_nonexistent_toolkit(mock_session):
    """An unknown toolkit name yields None after a single filter_by(...).first() lookup."""
    toolkit_name = "nonexistent_toolkit"
    organisation = Organisation(id=1)
    filter_by = mock_session.query.return_value.filter_by
    filter_by.return_value.first.return_value = None

    result = Toolkit.get_toolkit_from_name(mock_session, toolkit_name, organisation)

    assert result is None
    mock_session.query.assert_called_once_with(Toolkit)
    filter_by.assert_called_once_with(name=toolkit_name, organisation_id=organisation.id)
    filter_by.return_value.first.assert_called_once()
def test_get_toolkit_installed_details(mock_session):
    """Marketplace entries are flagged is_installed when a same-named toolkit is installed."""
    marketplace_toolkits = [{"name": f"Toolkit {i}"} for i in (1, 2, 3)]
    organisation = Organisation(id=1)
    installed_toolkits = [Toolkit(name="Toolkit 1"), Toolkit(name="Toolkit 3")]
    mock_session.query.return_value.filter.return_value.all.return_value = installed_toolkits

    result = Toolkit.get_toolkit_installed_details(mock_session, marketplace_toolkits, organisation)

    assert len(result) == 3
    expected = [("Toolkit 1", True), ("Toolkit 2", False), ("Toolkit 3", True)]
    for entry, (name, installed) in zip(result, expected):
        assert entry["name"] == name
        assert entry["is_installed"] is installed
    mock_session.query.assert_called_once()
    mock_session.query.return_value.filter.return_value.all.assert_called_once()
# Test function
def test_fetch_tool_ids_from_toolkit(mock_tool, mock_session):
    """The session mock returns [mock_tool] for every lookup, so one id per toolkit id is expected."""
    toolkit_ids = [1, 2, 3]

    tool_ids = Toolkit.fetch_tool_ids_from_toolkit(mock_session, toolkit_ids)

    assert tool_ids == [mock_tool.id] * len(toolkit_ids)
def test_get_tool_and_toolkit_arr_with_nonexistent_toolkit():
    """An unknown toolkit name makes get_tool_and_toolkit_arr raise with a descriptive message."""
    session = create_autospec(Session)
    # Every toolkit lookup misses.
    session.query.return_value.filter.return_value.first.return_value = None
    missing_toolkits = [{"name": "NonExistentToolkit", "tools": ["Tool1", "Tool2"]}]

    with pytest.raises(Exception) as exc_info:
        Toolkit.get_tool_and_toolkit_arr(session, 1, missing_toolkits)

    message = str(exc_info.value)
    assert "One or more of the Tool(s)/Toolkit(s) does not exist." in message
================================================
FILE: tests/unit_tests/models/test_vector_db_configs.py
================================================
import unittest
from unittest.mock import Mock, patch
from superagi.models.vector_db_configs import VectordbConfigs
class TestVectordbConfigs(unittest.TestCase):
    """Unit tests for VectordbConfigs read/add/delete helpers against a mocked session.

    The model class is patched in its own module so the helpers do not build real
    SQLAlchemy rows. Mocks are configured and inspected through ``return_value``
    chains instead of calling ``query()``/``filter()`` from the test. Such calls would
    record extra calls on the mock and only look as if they check the arguments.
    """

    def setUp(self):
        self.session_mock = Mock()
        self.vector_db_id_mock = 1
        self.db_creds_mock = {"key1": "value1", "key2": "value2"}

    @patch('superagi.models.vector_db_configs.VectordbConfigs')
    def test_get_vector_db_config_from_db_id(self, model_mock):
        vectordb_mock = Mock()
        vectordb_mock.key = "key1"
        vectordb_mock.value = "value1"
        self.session_mock.query.return_value.filter.return_value.all.return_value = [vectordb_mock]

        result = VectordbConfigs.get_vector_db_config_from_db_id(self.session_mock, self.vector_db_id_mock)

        self.assertEqual(result, {"key1": "value1"})

    @patch('superagi.models.vector_db_configs.VectordbConfigs')
    def test_add_vector_db_config(self, model_mock):
        VectordbConfigs.add_vector_db_config(self.session_mock, self.vector_db_id_mock, self.db_creds_mock)

        # One row per credential key, then a commit.
        self.assertEqual(self.session_mock.add.call_count, len(self.db_creds_mock))
        self.assertTrue(self.session_mock.commit.called)

    @patch('superagi.models.vector_db_configs.VectordbConfigs')
    def test_delete_vector_db_configs(self, model_mock):
        VectordbConfigs.delete_vector_db_configs(self.session_mock, self.vector_db_id_mock)

        delete_mock = self.session_mock.query.return_value.filter.return_value.delete
        self.assertTrue(delete_mock.called)
        self.assertTrue(self.session_mock.commit.called)
if __name__ == "__main__":
    # Allow running this test module directly with ``python <file>``.
    unittest.main()
================================================
FILE: tests/unit_tests/models/test_vector_db_indices.py
================================================
import unittest
from unittest.mock import Mock, MagicMock, call
from superagi.models.vector_db_indices import VectordbIndices
class TestVectordbIndices(unittest.TestCase):
    """Unit tests for VectordbIndices' session-based helpers, using a mocked session."""

    def setUp(self):
        session = Mock()
        self.mock_session = session
        # Handles on query(...) and query(...).filter(...) so each test can check
        # which terminal call (first/all/delete) the helper issued.
        self.query_mock = session.query.return_value
        self.filter_mock = self.query_mock.filter.return_value

    def _assert_queried_indices(self):
        """Check that the most recent session.query call targeted VectordbIndices."""
        self.mock_session.query.assert_called_with(VectordbIndices)

    def test_get_vector_index_from_id(self):
        VectordbIndices.get_vector_index_from_id(self.mock_session, 1)
        self._assert_queried_indices()
        self.filter_mock.first.assert_called_once()

    def test_get_vector_indices_from_vectordb(self):
        VectordbIndices.get_vector_indices_from_vectordb(self.mock_session, 1)
        self._assert_queried_indices()
        self.filter_mock.all.assert_called_once()

    def test_delete_vector_db_index(self):
        VectordbIndices.delete_vector_db_index(self.mock_session, 1)
        self._assert_queried_indices()
        self.filter_mock.delete.assert_called_once()
        self.mock_session.commit.assert_called_once()

    def test_add_vector_index(self):
        VectordbIndices.add_vector_index(self.mock_session, 'test', 1, 100, 'active')
        for recorded in (self.mock_session.add, self.mock_session.commit):
            recorded.assert_called_once()

    def test_update_vector_index_state(self):
        VectordbIndices.update_vector_index_state(self.mock_session, 1, 'inactive')
        self._assert_queried_indices()
        self.filter_mock.first.assert_called_once()
        self.mock_session.commit.assert_called_once()
if __name__ == "__main__":
    # Allow running this test module directly with ``python <file>``.
    unittest.main()
================================================
FILE: tests/unit_tests/models/test_vector_dbs.py
================================================
import unittest
from unittest.mock import Mock, patch
from superagi.models.vector_dbs import Vectordbs
class TestVectordbs(unittest.TestCase):
    """Unit tests for the Vectordbs model's marketplace and CRUD helpers."""

    def setUp(self):
        # Each test gets a fresh mocked DB session plus an unsaved Vectordbs
        # instance that the mocked queries can hand back.
        self.mock_session = Mock()
        self.test_vector_db = Vectordbs(name='test_db', db_type='test_db_type', organisation_id=1)

    @patch('requests.get')
    def test_fetch_marketplace_list(self, mock_get):
        """A 200 marketplace response is returned as its decoded JSON list."""
        mock_get.return_value = Mock(status_code=200, **{'json.return_value': [{'name': 'test_db'}]})
        self.assertListEqual(Vectordbs.fetch_marketplace_list(), [{'name': 'test_db'}])

    def test_get_vector_db_from_id(self):
        """Lookup by id returns the first row of the filtered query."""
        lookup = self.mock_session.query.return_value.filter.return_value
        lookup.first.return_value = self.test_vector_db
        found = Vectordbs.get_vector_db_from_id(self.mock_session, 1)
        self.assertEqual(found, self.test_vector_db)

    def test_get_vector_db_from_organisation(self):
        """Lookup by organisation returns every matching row."""
        lookup = self.mock_session.query.return_value.filter.return_value
        lookup.all.return_value = [self.test_vector_db]
        organisation = Mock(id=1)
        found = Vectordbs.get_vector_db_from_organisation(self.mock_session, organisation)
        self.assertIn(self.test_vector_db, found)

    def test_add_vector_db(self):
        """The created row carries the requested name."""
        created = Vectordbs.add_vector_db(self.mock_session, 'test_db', 'test_db_type', Mock(id=1))
        self.assertEqual(created.name, 'test_db')

    def test_delete_vector_db(self):
        """Deletion queries the Vectordbs model once and deletes the filtered rows."""
        Vectordbs.delete_vector_db(self.mock_session, 1)
        query = self.mock_session.query
        query.assert_called_once_with(Vectordbs)
        query.return_value.filter.return_value.delete.assert_called_once()
if "__main__" == __name__:
    # Allow running this test module directly with ``python <file>``.
    unittest.main()
================================================
FILE: tests/unit_tests/resource_manager/__init__.py
================================================
================================================
FILE: tests/unit_tests/resource_manager/test_file_manager.py
================================================
import pytest
from unittest.mock import Mock, patch
from superagi.models.resource import Resource
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from superagi.lib.logger import logger
from superagi.resource_manager.file_manager import FileManager
@pytest.fixture
def resource_manager():
    """Provide a FileManager backed by a mocked DB session (no agent id is set)."""
    return FileManager(Mock())
def test_write_binary_file(resource_manager):
    """write_binary_file reports and logs success when path, resource and S3 helpers are stubbed."""
    expected_message = "Binary test.png saved successfully"
    written_resource = Resource(name='test.png', storage_type='S3')
    with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'):
        with patch.object(ResourceHelper, 'make_written_file_resource', return_value=written_resource):
            with patch.object(S3Helper, 'upload_file'):
                with patch.object(logger, 'info') as logger_mock:
                    outcome = resource_manager.write_binary_file('test.png', b'data')
                    assert outcome == expected_message
                    logger_mock.assert_called_once_with(expected_message)
def test_write_file(resource_manager):
    """write_file reports and logs success when path, resource and S3 helpers are stubbed."""
    expected_message = "test.txt - File written successfully"
    written_resource = Resource(name='test.txt', storage_type='S3')
    with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'):
        with patch.object(ResourceHelper, 'make_written_file_resource', return_value=written_resource):
            with patch.object(S3Helper, 'upload_file'):
                with patch.object(logger, 'info') as logger_mock:
                    outcome = resource_manager.write_file('test.txt', 'content')
                    assert outcome == expected_message
                    logger_mock.assert_called_once_with(expected_message)
================================================
FILE: tests/unit_tests/resource_manager/test_llama_document_creation.py
================================================
import pytest
from unittest.mock import patch, MagicMock
from superagi.resource_manager.resource_manager import ResourceManager
def test_create_llama_document_s3(mocker):
    """create_llama_document_s3 fetches the object from S3 and loads it via SimpleDirectoryReader."""
    resource_manager = ResourceManager('test_agent')

    s3_body = MagicMock(read=MagicMock(return_value='mock_file_content'))
    s3_client = MagicMock()
    s3_client.get_object.return_value = {'Body': s3_body}
    mocker.patch('boto3.client', return_value=s3_client)
    # Successive get_config results; presumably access key, secret key, then
    # bucket name (the order matches the values below — TODO confirm).
    mocker.patch('superagi.resource_manager.resource_manager.get_config',
                 side_effect=['mock_access_key', 'mock_secret_key', 'mock_bucket'])
    # Stub file writes and deletion so nothing touches the real file system.
    mocker.patch('builtins.open', mocker.mock_open())
    mocker.patch('os.remove')
    reader = MagicMock()
    mocker.patch('superagi.resource_manager.resource_manager.SimpleDirectoryReader',
                 return_value=reader)

    resource_manager.create_llama_document_s3('mock_file_path')

    s3_client.get_object.assert_called_once_with(Bucket='mock_bucket', Key='mock_file_path')
    reader.load_data.assert_called_once()
def test_create_llama_document_s3_file_path_provided(mocker):
    """Despite its name, this checks the *missing* path case: passing None raises.

    S3, config, file-system and reader dependencies are stubbed the same way as
    in the happy-path test, so the call cannot reach real S3 or disk.
    """
    manager = ResourceManager('test_agent')
    mocker.patch('boto3.client', return_value=MagicMock())
    mocker.patch('superagi.resource_manager.resource_manager.get_config',
                 side_effect=['mock_access_key', 'mock_secret_key', 'mock_bucket'])
    mocker.patch('builtins.open', mocker.mock_open())
    mocker.patch('os.remove')
    mocker.patch('superagi.resource_manager.resource_manager.SimpleDirectoryReader',
                 return_value=MagicMock())

    with pytest.raises(Exception, match="file_path must be provided"):
        manager.create_llama_document_s3(None)
================================================
FILE: tests/unit_tests/resource_manager/test_llama_vector_store_factory.py
================================================
import pytest
from unittest.mock import patch
from llama_index.vector_stores import PineconeVectorStore, RedisVectorStore
from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
from superagi.types.vector_store_types import VectorStoreType
def test_llama_vector_store_factory():
    """get_vector_store builds Pinecone/Redis stores and rejects unknown store types."""
    factory = LlamaVectorStoreFactory(VectorStoreType.PINECONE, "test_index_name")

    # Pinecone: the real constructor is stubbed out.
    with patch.object(PineconeVectorStore, "__init__", return_value=None):
        pinecone_store = factory.get_vector_store()
        assert isinstance(pinecone_store, PineconeVectorStore)

    # Redis: constructor stubbed and config lookups return None.
    # NOTE(review): patching 'superagi.config.config.get_config' only affects
    # lookups through that module attribute; if the factory imported get_config
    # by name, this patch has no effect — TODO confirm.
    factory.vector_store_name = VectorStoreType.REDIS
    with patch.object(RedisVectorStore, "__init__", return_value=None):
        with patch('superagi.config.config.get_config', return_value=None):
            redis_store = factory.get_vector_store()
            assert isinstance(redis_store, RedisVectorStore)

    # Any other name is rejected with a descriptive ValueError.
    factory.vector_store_name = "unknown"
    with pytest.raises(ValueError) as excinfo:
        factory.get_vector_store()
    assert str(excinfo.value) == "unknown vector store is not supported yet."
================================================
FILE: tests/unit_tests/resource_manager/test_save_document_to_vector_store.py
================================================
from unittest.mock import patch, Mock
from llama_index import VectorStoreIndex, StorageContext, Document
from superagi.resource_manager.resource_manager import ResourceManager
from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
@patch.object(LlamaVectorStoreFactory, 'get_vector_store')
@patch.object(StorageContext, 'from_defaults')
@patch.object(VectorStoreIndex, 'from_documents')
def test_save_document_to_vector_store(mock_vc_from_docs, mock_sc_from_defaults, mock_get_vector_store):
    """Documents are indexed into the factory's store through a StorageContext, then persisted.

    Patch decorators apply bottom-up, so the mocks arrive as
    (from_documents, from_defaults, get_vector_store).
    """
    store = Mock()
    mock_get_vector_store.return_value = store
    mock_sc_from_defaults.return_value = "mock_storage_context"
    mock_vc_from_docs.return_value = "mock_index"

    manager = ResourceManager("test_agent_id")
    docs = [Document(text=body) for body in ("doc1", "doc2")]
    manager.save_document_to_vector_store(docs, "test_resource_id", "test_model_api_key")

    mock_get_vector_store.assert_called_once()
    mock_sc_from_defaults.assert_called_once_with(vector_store=store)
    mock_vc_from_docs.assert_called_once_with(docs, storage_context="mock_storage_context")
    store.persist.assert_called_once()
================================================
FILE: tests/unit_tests/test_migrations_multiheads.py
================================================
import glob
import os
import re
import pytest
from collections import Counter
def test_alembic_down_revision():
    """Fail when two Alembic migration files share the same ``down_revision``.

    Two revisions branching off the same parent produce multiple heads. This
    usually means a newer migration was added after yours was written.

    ``migrations/versions`` is resolved relative to the current working
    directory, so run pytest from the repository root. From anywhere else no
    files are found and the check passes vacuously.
    """
    versions_dir = os.path.join('.', 'migrations', 'versions')
    all_py_files = glob.glob(os.path.join(versions_dir, "*.py"))
    # Accept the classic single-quoted form, the double-quoted form (e.g. after
    # running a code formatter), and the type-annotated form used by newer
    # Alembic templates: ``down_revision: Union[str, None] = 'abc'``.
    down_revision_pattern = re.compile(
        r'''down_revision(?:\s*:[^=\n]*)?\s*=\s*['"](\w+)['"]'''
    )
    file_down_revisions = []
    for file in all_py_files:
        with open(file) as f:
            match = down_revision_pattern.search(f.read())
        # The root migration (down_revision = None) and merge migrations (a
        # tuple of parents) have no single quoted parent and are skipped.
        if match:
            file_down_revisions.append((file, match.group(1)))
    counter = Counter(down_revision for _, down_revision in file_down_revisions)
    duplicates = [item for item, count in counter.items() if count > 1]
    # get the files that have duplicate down revisions
    files_with_duplicates = [file for file, down_revision in file_down_revisions if down_revision in duplicates]
    assert len(duplicates) == 0, f"Duplicate down revisions found in files: {files_with_duplicates} \n this is " \
                                 f"caused because a newer migration might have been added after the migration " \
                                 f"you added. Please fix this by changing the down_revision to the correct one."
================================================
FILE: tests/unit_tests/test_tool_manager.py
================================================
import json
import os
import shutil
import tempfile
from unittest.mock import Mock, patch
import pytest
from superagi.tool_manager import parse_github_url, download_tool, load_tools_config, download_and_extract_tools, \
update_tools_json
@pytest.fixture
def tools_json_path():
    """Yield a ``tools.json`` path inside a throwaway directory.

    The file itself is not created; the directory and anything written into it
    are removed once the test finishes.
    """
    scratch = tempfile.TemporaryDirectory()
    try:
        yield os.path.join(scratch.name, "tools.json")
    finally:
        scratch.cleanup()
def test_parse_github_url():
    """A bare repo URL parses to 'owner/repo/main' ('main' appended when no branch is given)."""
    parsed = parse_github_url('https://github.com/owner/repo')
    assert parsed == 'owner/repo/main'
def setup_function():
    """Per-test pytest hook: ensure ``target_folder`` exists in the working directory."""
    folder = 'target_folder'
    os.makedirs(folder, exist_ok=True)
# Teardown function to remove the directory
def teardown_function():
    """Per-test pytest hook: remove the ``target_folder`` directory.

    Tolerates the folder already being gone (e.g. a test removed it), so that
    teardown does not mask the real test outcome with a FileNotFoundError.
    """
    if os.path.isdir('target_folder'):
        shutil.rmtree('target_folder')
@patch('requests.get')
@patch('zipfile.ZipFile')
def test_download_tool(mock_zip, mock_get):
    """download_tool fetches the repo zipball from the GitHub API and opens it for reading."""
    mock_get.return_value = Mock(content=b'file content')
    archive = mock_zip.return_value.__enter__.return_value
    archive.namelist.return_value = ['owner-repo/somefile.txt']

    download_tool('https://github.com/owner/repo', 'target_folder')

    mock_get.assert_called_once_with('https://api.github.com/repos/owner/repo/zipball/main')
    mock_zip.assert_called_once_with('target_folder/tool.zip', 'r')
@patch('json.load')
def test_load_tools_config(mock_json_load):
    """load_tools_config returns the ``tools`` mapping from the parsed JSON."""
    mock_json_load.return_value = {"tools": dict(tool1="url1", tool2="url2")}
    loaded = load_tools_config()
    assert loaded == {"tool1": "url1", "tool2": "url2"}
@patch('superagi.tool_manager.download_tool')
@patch('superagi.tool_manager.load_tools_config')
def test_download_and_extract_tools(mock_load_tools_config, mock_download_tool):
    """Every configured tool is downloaded into superagi/tools/external_tools/<name>."""
    mock_load_tools_config.return_value = {"tool1": "url1", "tool2": "url2"}

    download_and_extract_tools()

    mock_load_tools_config.assert_called_once()
    for tool_name, url in (("tool1", "url1"), ("tool2", "url2")):
        destination = os.path.join('superagi', 'tools', 'external_tools', tool_name)
        mock_download_tool.assert_any_call(url, destination)
def test_update_tools_json(tools_json_path):
    """update_tools_json merges new folder links into the existing ``tools`` mapping."""
    # Seed the file with two pre-existing tools.
    with open(tools_json_path, "w") as seed:
        json.dump({"tools": {"tool1": "link1", "tool2": "link2"}}, seed)

    update_tools_json(tools_json_path, {"tool3": "link3", "tool4": "link4"})

    with open(tools_json_path, "r") as reader:
        updated_data = json.load(reader)
    # Old entries survive and new ones are added alongside them.
    expected_tools = {f"tool{i}": f"link{i}" for i in range(1, 5)}
    assert updated_data == {"tools": expected_tools}
================================================
FILE: tests/unit_tests/tools/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/code/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/code/test_improve_code.py
================================================
import pytest
from unittest.mock import Mock, MagicMock
from superagi.tools.code.improve_code import ImproveCodeTool
@pytest.fixture
def mock_improve_code_tool():
    """Provide an ImproveCodeTool whose resource manager and LLM are plain mocks."""
    tool = ImproveCodeTool()
    tool.resource_manager, tool.llm = Mock(), Mock()
    return tool
def test_execute(mock_improve_code_tool):
    """Each listed file is improved by the LLM and written back; the summary names every file."""
    tool = mock_improve_code_tool
    files = tool.resource_manager
    files.get_files.return_value = ['test1', 'test2']
    files.read_file.return_value = "test file content"
    files.write_file.return_value = "file saved successfully"
    # The LLM reply nests the code block under response -> choices[0] -> message.
    choice = {"message": {"content": "```\nimproved code\n```"}}
    tool.llm.chat_completion.return_value = {"response": {"choices": [choice]}}

    summary = tool._execute()

    assert summary == "All codes improved and saved successfully in: test1 test2"
def test_execute_with_error(mock_improve_code_tool):
    """When write_file returns an error string, _execute returns that string unchanged."""
    tool = mock_improve_code_tool
    files = tool.resource_manager
    files.get_files.return_value = ['test1']
    files.read_file.return_value = "test file content"
    files.write_file.return_value = "Error: Could not save file"
    choice = {"message": {"content": "```\nimproved code\n```"}}
    tool.llm.chat_completion.return_value = {"response": {"choices": [choice]}}

    outcome = tool._execute()

    assert outcome == "Error: Could not save file"
================================================
FILE: tests/unit_tests/tools/code/test_write_code.py
================================================
from unittest.mock import Mock
import pytest
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.code.write_code import CodingTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
from unittest.mock import MagicMock
class MockBaseLlm:
    """Minimal LLM stand-in that answers every prompt with a fixed two-file code reply."""

    def get_model(self):
        """Model name this fake pretends to be."""
        return "gpt-3.5-turbo"

    def chat_completion(self, messages, max_tokens):
        """Ignore the prompt; return two fenced Python blocks, each preceded by a file name."""
        first_file = "File1.py\n```python\nprint('Hello World')\n```"
        second_file = "File2.py\n```python\nprint('Hello again')\n```"
        return {"content": first_file + "\n\n" + second_file}
class TestCodingTool:
    """Tests for CodingTool using a canned LLM and mocked file/response managers."""

    @pytest.fixture
    def tool(self):
        coding_tool = CodingTool()
        coding_tool.llm = MockBaseLlm()
        coding_tool.resource_manager = Mock(spec=FileManager)
        coding_tool.tool_response_manager = Mock(spec=ToolResponseQueryManager)
        # A MagicMock stands in for the DB session held by the toolkit config.
        coding_tool.toolkit_config.session = MagicMock(name="session")
        return coding_tool

    def test_execute(self, tool):
        """The LLM reply is echoed with a success suffix and every file is written."""
        tool.resource_manager.write_file.return_value = "File write successful"
        tool.tool_response_manager.get_last_response.return_value = "Mocked Spec"
        llm_output = "File1.py\n```python\nprint('Hello World')\n```\n\nFile2.py\n```python\nprint('Hello again')\n```"

        response = tool._execute("Test spec description")

        assert response == llm_output + "\n Codes generated and saved successfully in File1.py, File2.py"
        for file_name, contents in (
            ("README.md", 'File1.py\n'),
            ("File1.py", "print('Hello World')\n"),
            ("File2.py", "print('Hello again')\n"),
        ):
            tool.resource_manager.write_file.assert_any_call(file_name, contents)
        tool.tool_response_manager.get_last_response.assert_called_once_with("WriteSpecTool")
================================================
FILE: tests/unit_tests/tools/code/test_write_spec.py
================================================
from unittest.mock import Mock
import pytest
from superagi.tools.code.write_spec import WriteSpecTool
from unittest.mock import MagicMock
class MockBaseLlm:
    """Minimal LLM stand-in that always answers with a fixed specification."""

    def get_model(self):
        """Model name this fake pretends to be."""
        return "gpt-3.5-turbo"

    def chat_completion(self, messages, max_tokens):
        """Ignore the prompt and return the canned specification text."""
        reply = "Generated specification"
        return {"content": reply}
class TestWriteSpecTool:
    """Tests for WriteSpecTool with a canned LLM and a mocked resource manager."""

    @pytest.fixture
    def tool(self):
        spec_tool = WriteSpecTool()
        spec_tool.llm = MockBaseLlm()
        spec_tool.resource_manager = Mock()
        # A MagicMock stands in for the DB session held by the toolkit config.
        spec_tool.toolkit_config.session = MagicMock(name="session")
        return spec_tool

    def test_execute(self, tool):
        """The generated spec is saved to the requested file and echoed with a success note."""
        writer = Mock(return_value="File write successful")
        tool.resource_manager.write_file = writer

        response = tool._execute("Test task description", "test_spec_file.txt")

        assert response == "Generated specification\nSpecification generated and saved successfully"
        writer.assert_called_once_with("test_spec_file.txt", "Generated specification")
================================================
FILE: tests/unit_tests/tools/code/test_write_test.py
================================================
from unittest.mock import Mock, patch
from superagi.tools.code.write_test import WriteTestTool
from unittest.mock import MagicMock
def test_write_test_tool_init():
    """A freshly constructed WriteTestTool has its default, not-yet-wired state."""
    tool = WriteTestTool()
    for unset_attr in ("llm", "agent_id"):
        assert getattr(tool, unset_attr) is None
    assert tool.name == "WriteTestTool"
    assert tool.description is not None
    assert tool.goals == []
    assert tool.resource_manager is None
@patch('superagi.tools.code.write_test.PromptReader')
@patch('superagi.tools.code.write_test.AgentPromptBuilder')
@patch('superagi.tools.code.write_test.TokenCounter')
def test_execute(mock_token_counter, mock_agent_prompt_builder, mock_prompt_reader):
    """WriteTestTool._execute prompts the LLM and saves each generated code block.

    Patch decorators apply bottom-up, so the mocks arrive as
    (TokenCounter, AgentPromptBuilder, PromptReader).
    """
    test_tool = WriteTestTool()
    test_tool.tool_response_manager = Mock()
    test_tool.resource_manager = Mock()
    test_tool.llm = Mock()
    mock_session = MagicMock(name="session")
    test_tool.toolkit_config.session = mock_session
    test_tool.tool_response_manager.get_last_response.return_value = 'WriteSpecTool response'
    mock_prompt_reader.read_tools_prompt.return_value = 'Prompt template {goals} {test_description} {spec}'
    mock_agent_prompt_builder.add_list_items_to_string.return_value = 'Goals string'
    test_tool.llm.get_model.return_value = 'Model'
    # count_message_tokens is used on the TokenCounter class itself...
    mock_token_counter.count_message_tokens.return_value = 100
    # ...but token_limit is called on a TokenCounter *instance* (see the
    # assertion below), so the stub must sit on the instance mock. Stubbing it
    # on the class mock left the real limit as an unconfigured MagicMock.
    token_counter_instance = mock_token_counter.return_value
    token_counter_instance.token_limit.return_value = 1000
    test_tool.llm.chat_completion.return_value = {
        'content': 'File1\n```\nCode1```File2\n```\nCode2```',
    }
    test_tool.resource_manager.write_file.return_value = 'Success'

    result = test_tool._execute('Test description', 'test_file.py')

    assert 'File1' in result
    assert 'Code1' in result
    assert 'File2' in result
    assert 'Code2' in result
    assert 'Tests generated and saved successfully in test_file.py' in result
    mock_prompt_reader.read_tools_prompt.assert_called_once()
    mock_agent_prompt_builder.add_list_items_to_string.assert_called_once_with(test_tool.goals)
    test_tool.tool_response_manager.get_last_response.assert_called()
    test_tool.llm.get_model.assert_called()
    mock_token_counter.count_message_tokens.assert_called()
    # Inspect via return_value: calling mock_token_counter() here would record
    # a spurious constructor call.
    token_counter_instance.token_limit.assert_called()
    test_tool.llm.chat_completion.assert_called()
    assert test_tool.resource_manager.write_file.call_count == 2
================================================
FILE: tests/unit_tests/tools/duck_duck_go/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/duck_duck_go/test_duckduckgo_results.py
================================================
import unittest
from unittest.mock import patch
import pytest
from superagi.tools.duck_duck_go.duck_duck_go_search import DuckDuckGoSearchTool
class TestDuckDuckGoSearchTool:
    """Tests for DuckDuckGoSearchTool's result fetching and formatting helpers."""

    def setup_method(self):
        # A fresh tool per test so no state leaks between cases.
        self.your_obj = DuckDuckGoSearchTool()

    def test_get_raw_duckduckgo_results_empty_query(self):
        """An empty query yields the string "[]" (not an empty list)."""
        assert self.your_obj.get_raw_duckduckgo_results("") == "[]"

    @patch('superagi.tools.duck_duck_go.duck_duck_go_search.DuckDuckGoSearchTool.get_raw_duckduckgo_results')
    def test_get_raw_duckduckgo_results_valid_query(self, mock_get_raw_duckduckgo_results):
        # NOTE(review): this patches the very method it then calls, so it only
        # checks the mock's canned return value, not the tool's behaviour.
        mock_get_raw_duckduckgo_results.return_value = ['result%d' % i for i in range(1, 11)]
        raw = self.your_obj.get_raw_duckduckgo_results("python")
        assert len(raw) == 10

    def test_get_formatted_webpages(self):
        """Search hits are paired positionally with page bodies; 'href' is exposed as 'links'."""
        hits = [
            {"title": "Result 1", "href": "https://example.com/1"},
            {"title": "Result 2", "href": "https://example.com/2"},
            {"title": "Result 3", "href": "https://example.com/3"},
        ]
        bodies = ["Webpage 1", "Webpage 2", "Webpage 3"]
        formatted = self.your_obj.get_formatted_webpages(hits, bodies)
        assert formatted == [
            {"title": "Result 1", "body": "Webpage 1", "links": "https://example.com/1"},
            {"title": "Result 2", "body": "Webpage 2", "links": "https://example.com/2"},
            {"title": "Result 3", "body": "Webpage 3", "links": "https://example.com/3"},
        ]

    def test_get_content_from_url_with_empty_links(self):
        """No links means no page contents."""
        assert self.your_obj.get_content_from_url([]) == []

    def test_get_formatted_webpages_with_empty_webpages(self):
        """Without page bodies nothing is formatted, even if there are search hits."""
        hits = [
            {"title": "Result 1", "href": "https://example.com/1"},
            {"title": "Result 2", "href": "https://example.com/2"},
            {"title": "Result 3", "href": "https://example.com/3"},
        ]
        assert self.your_obj.get_formatted_webpages(hits, []) == []
================================================
FILE: tests/unit_tests/tools/duck_duck_go/test_duckduckgo_toolkit.py
================================================
import pytest
from superagi.tools.duck_duck_go.duck_duck_go_search_toolkit import DuckDuckGoToolkit
from superagi.tools.duck_duck_go.duck_duck_go_search import DuckDuckGoSearchTool
class TestDuckDuckGoSearchToolKit:
    """Tests for DuckDuckGoToolkit's tool list and environment-key requirements."""

    def setup_method(self):
        """Build a fresh toolkit before every test method."""
        self.toolkit = DuckDuckGoToolkit()

    def test_get_tools(self):
        """The toolkit exposes exactly one tool, a DuckDuckGoSearchTool."""
        tools = self.toolkit.get_tools()
        assert len(tools) == 1
        first_tool = tools[0]
        assert isinstance(first_tool, DuckDuckGoSearchTool)

    def test_get_env_keys(self):
        """No environment keys are required by this toolkit."""
        required_keys = self.toolkit.get_env_keys()
        assert len(required_keys) == 0
================================================
FILE: tests/unit_tests/tools/email/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/email/test_read_email.py
================================================
from unittest.mock import patch, Mock
from superagi.tools.email.read_email import ReadEmailTool
@patch('superagi.tools.email.read_email.ImapEmail')
@patch('superagi.tools.email.read_email.ReadEmail')
def test_execute(mock_read_email, mock_imap_email):
    """ReadEmailTool._execute returns the parsed emails from the mocked INBOX.

    IMAP access and header/body parsing are mocked, so every fetched message
    parses to the same canned headers and cleaned body.
    """
    # Configure the mock objects
    mock_conn = Mock()
    mock_conn.select.return_value = ('OK', ['10'])  # the mocked INBOX reports 10 messages
    mock_conn.fetch.return_value = ('OK', [(b'1 (RFC822 {337}', b'Some email content')])
    mock_imap_email.return_value.imap_open.return_value = mock_conn
    mock_read_email.return_value.obtain_header.return_value = ('From', 'To', 'Date', 'Subject')
    mock_read_email.return_value.clean_email_body.return_value = 'Cleaned email body'
    # Set up ReadEmailTool object; every tool config lookup returns a placeholder.
    tool = ReadEmailTool()
    tool.toolkit_config.get_tool_config = Mock(return_value='dummy_value')
    # Call the function
    result = tool._execute()
    # Five emails are expected even though the INBOX reports 10. Presumably
    # this is the tool's default read limit — NOTE(review): confirm against
    # ReadEmailTool.
    assert len(result) == 5
    assert result[0]['From'] == 'From'
    assert result[0]['To'] == 'To'
    assert result[0]['Date'] == 'Date'
    assert result[0]['Subject'] == 'Subject'
    assert result[0]['Message Body'] == 'Cleaned email body'
================================================
FILE: tests/unit_tests/tools/email/test_send_email.py
================================================
from unittest.mock import patch, MagicMock
from superagi.tools.email.send_email import SendEmailTool
def mock_get_tool_config(key):
    """Fake toolkit config lookup with draft mode disabled; unknown keys yield None."""
    return dict(
        EMAIL_ADDRESS='sender@example.com',
        EMAIL_PASSWORD='password',
        EMAIL_SIGNATURE='',
        EMAIL_DRAFT_MODE='False',
        EMAIL_DRAFT_FOLDER='Drafts',
        EMAIL_IMAP_SERVER='imap.example.com',
        EMAIL_SMTP_HOST='host',
        EMAIL_SMTP_PORT='port',
    ).get(key)
def mock_get_draft_tool_config(key):
    """Fake toolkit config lookup with draft mode enabled; unknown keys yield None."""
    return dict(
        EMAIL_ADDRESS='sender@example.com',
        EMAIL_PASSWORD='password',
        EMAIL_SIGNATURE='',
        EMAIL_DRAFT_MODE='True',
        EMAIL_DRAFT_FOLDER='Drafts',
        EMAIL_IMAP_SERVER='imap.example.com',
        EMAIL_SMTP_HOST='host',
        EMAIL_SMTP_PORT='port',
    ).get(key)
@patch('smtplib.SMTP')
@patch('superagi.helper.imap_email.ImapEmail.imap_open')
def test_execute_sends_email(mock_imap_open, mock_smtp):
    """With draft mode off, _execute sends over SMTP using the configured host and port."""
    # Given
    send_email_tool = SendEmailTool()
    send_email_tool.toolkit_config.get_tool_config = mock_get_tool_config
    # When
    result = send_email_tool._execute('receiver@example.com', 'test subject', 'test body')
    # Then
    assert result == 'Email was sent to receiver@example.com'
    mock_smtp.assert_called_once_with('host', 'port')
@patch('smtplib.SMTP')
@patch('superagi.helper.imap_email.ImapEmail.imap_open')
def test_execute_sends_email_to_draft(mock_imap_open, mock_smtp):
    """With draft mode on, the message is appended over IMAP and SMTP is never used."""
    tool = SendEmailTool()
    tool.toolkit_config.get_tool_config = mock_get_draft_tool_config

    outcome = tool._execute('receiver@example.com', 'test subject', 'test body')

    assert outcome == 'Email went to Drafts'
    mock_imap_open.assert_called_once_with('Drafts', 'sender@example.com', 'password', 'imap.example.com')
    mock_imap_open.return_value.append.assert_called_once()
    mock_smtp.assert_not_called()
================================================
FILE: tests/unit_tests/tools/email/test_send_email_attachment.py
================================================
# import unittest
# from unittest.mock import patch, MagicMock, ANY
# from superagi.models.agent import Agent
# import os
# from superagi.tools.email.send_email_attachment import SendEmailAttachmentTool, SendEmailAttachmentInput
# import tempfile
# class TestSendEmailAttachmentTool(unittest.TestCase):
# # Create a new class-level test file
# testFile = tempfile.NamedTemporaryFile(delete=True)
# @patch("superagi.models.agent.Agent.get_agent_from_id")
# @patch("superagi.tools.email.send_email_attachment.SendEmailAttachmentTool.send_email_with_attachment")
# @patch("superagi.helper.resource_helper.ResourceHelper.get_agent_read_resource_path")
# @patch("superagi.helper.resource_helper.ResourceHelper.get_root_input_dir")
# @patch("os.path.exists", return_value=os.path.exists(testFile.name))
# @patch("superagi.helper.s3_helper.S3Helper.read_binary_from_s3")
# def test__execute(self, mock_s3_file_read, mock_exists, mock_get_root_input_dir, mock_get_agent_resource_path,
# mock_send_email_with_attachment, mock_get_agent_from_id):
# # Arrange
# tool = SendEmailAttachmentTool()
# tool.agent_id = 1
# mock_get_agent_resource_path.return_value = self.testFile.name
# mock_get_root_input_dir.return_value = "/root_dir/"
# mock_send_email_with_attachment.return_value = "Email sent"
# expected_result = "Email sent"
# mock_get_agent_from_id.return_value = Agent(id=1, name='Test Agent')
# tool.agent_execution_id = 1
# tool.toolkit_config.session = MagicMock()
# mock_s3_file_read.return_value = b"file contents"
# # Act
# result = tool._execute("test@example.com", "test subject", "test body", "test.txt")
# # Assert
# self.assertEqual(result, expected_result)
# mock_send_email_with_attachment.assert_called_once_with("test@example.com", "test subject", "test body", ANY)
# mock_s3_file_read.assert_called_once_with(self.testFile.name)
# if __name__ == "__main__":
# unittest.main()
================================================
FILE: tests/unit_tests/tools/file/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/file/test_list_files.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from superagi.tools.file.list_files import ListFileTool
@pytest.fixture
def list_file_tool():
    """Yield a ListFileTool bound to a dummy agent id and a mocked DB session."""
    tool = ListFileTool()
    tool.agent_id = 1  # dummy agent id; no real agent row is needed
    tool.toolkit_config.session = MagicMock(name="session")
    yield tool
def test_list_files(list_file_tool):
    """list_files collects names from every walked directory, skipping dot-files."""
    walk_result = [
        ('/path/to', ('subdir',), ('file1.txt', '.file2.txt')),
        ('/path/to/subdir', (), ('file3.txt', 'file4.txt')),
    ]
    with patch('os.walk', return_value=walk_result):
        found = list_file_tool.list_files('/path/to')
    assert found == ['file1.txt', 'file3.txt', 'file4.txt']
def test_execute(list_file_tool):
    """_execute returns exactly what list_files reports for the agent-level path."""
    agent_path_stub = MagicMock(return_value="SuperAGI/workspace/input/{agent_id}/")
    with patch.object(ListFileTool, 'list_files', return_value=['file1.txt', 'file2.txt']), \
            patch('superagi.helper.resource_helper.ResourceHelper.get_formatted_agent_level_path',
                  new=agent_path_stub):
        listed = list_file_tool._execute()
    assert listed == ['file1.txt', 'file2.txt']
================================================
FILE: tests/unit_tests/tools/file/test_read_file.py
================================================
import os
import pytest
import tempfile
from unittest.mock import MagicMock, patch
from superagi.tools.file.read_file import ReadFileTool
from superagi.models.agent_execution import AgentExecution
from superagi.tools.file.read_file import ReadFileTool
from superagi.models.agent import Agent
@pytest.fixture
def mock_os_path_exists():
    """Patch ``os.path.exists`` for the test and hand over the mock."""
    with patch("os.path.exists") as exists_stub:
        yield exists_stub
@pytest.fixture
def mock_os_makedirs():
    """Patch ``os.makedirs`` so no directories are created during the test."""
    with patch("os.makedirs") as makedirs_stub:
        yield makedirs_stub
@pytest.fixture
def mock_get_config():
    """Patch ``superagi.config.config.get_config`` and yield the mock."""
    with patch("superagi.config.config.get_config") as config_stub:
        yield config_stub
@pytest.fixture
def read_file_tool():
    """Provide a ReadFileTool bound to a dummy agent id.

    Returns:
        ReadFileTool: the configured tool (previously the fixture fell off the
        end and injected ``None`` into any test that requested it).
    """
    tool = ReadFileTool()
    tool.agent_id = 1  # Set a dummy agent ID for testing.
    return tool
@pytest.fixture
def mock_s3_helper():
    """Patch the ``S3Helper`` class in its defining module and yield the class mock."""
    with patch("superagi.helper.s3_helper.S3Helper") as s3_helper_cls:
        yield s3_helper_cls
@pytest.fixture
def mock_partition():
    """Patch ``unstructured.partition.auto.partition`` and yield the mock."""
    with patch("unstructured.partition.auto.partition") as partition_stub:
        yield partition_stub
@pytest.fixture
def mock_get_agent_from_id():
    """Patch ``Agent.get_agent_from_id`` so no database lookup happens."""
    with patch("superagi.models.agent.Agent.get_agent_from_id") as agent_lookup:
        yield agent_lookup
@pytest.fixture
def mock_get_agent_execution_from_id():
    """Patch ``AgentExecution.get_agent_execution_from_id`` and yield the mock."""
    target = "superagi.models.agent_execution.AgentExecution.get_agent_execution_from_id"
    with patch(target) as execution_lookup:
        yield execution_lookup
@pytest.fixture
def mock_resource_helper():
    """Patch ``ResourceHelper.get_agent_read_resource_path`` and yield the mock."""
    target = "superagi.helper.resource_helper.ResourceHelper.get_agent_read_resource_path"
    with patch(target) as read_path_stub:
        yield read_path_stub
def test_read_file_tool(mock_os_path_exists, mock_os_makedirs, mock_get_config, mock_s3_helper, mock_partition,
                        mock_get_agent_from_id, mock_get_agent_execution_from_id, mock_resource_helper):
    """With get_config mocked to "FILE", _execute returns the file's text."""
    mock_os_path_exists.return_value = True
    mock_partition.return_value = ["This is a file.", "This is the second line."]
    mock_get_config.return_value = "FILE"
    mock_get_agent_from_id.return_value = MagicMock()
    mock_get_agent_execution_from_id.return_value = MagicMock()
    tool = ReadFileTool()

    # delete=False keeps the file on disk after the handle closes; removed below.
    with tempfile.NamedTemporaryFile('w', delete=False, suffix='.txt') as tmp:
        tmp.write("This is a file.\nThis is the second line.")
    path = tmp.name
    mock_resource_helper.return_value = path

    try:
        content = tool._execute(path)
        assert isinstance(content, str)
        for expected_line in ("This is a file.", "This is the second line."):
            assert expected_line in content
    finally:
        os.remove(path)  # Ensure the temporary file is deleted
def test_read_file_tool_s3(mock_os_path_exists, mock_os_makedirs, mock_get_config, mock_s3_helper, mock_partition,
                           mock_get_agent_from_id, mock_get_agent_execution_from_id, mock_resource_helper):
    """With get_config mocked to "S3", _execute returns the text served by the S3 helper.

    The fixture file is read back through a context manager so its handle is
    closed before ``os.remove`` (a dangling handle leaks until GC and blocks
    deletion on Windows). Everything after the file is created runs inside
    ``try`` so the file is always cleaned up.
    """
    mock_os_path_exists.return_value = True
    mock_get_config.return_value = "S3"  # ensure this function returns "S3"
    mock_get_agent_from_id.return_value = MagicMock()
    mock_get_agent_execution_from_id.return_value = MagicMock()
    tool = ReadFileTool()
    with tempfile.NamedTemporaryFile('w', delete=False, suffix='.txt') as tmp:
        tmp.write("This is a file.\nThis is the second line.")
    try:
        mock_resource_helper.return_value = tmp.name
        with open(tmp.name, 'r') as handle:
            mock_s3_helper.return_value.read_from_s3.return_value = handle.read()
        result = tool._execute(tmp.name)
        assert isinstance(result, str)
        assert "This is a file." in result
        assert "This is the second line." in result
    finally:
        os.remove(tmp.name)  # Ensure the temporary file is deleted
def test_read_file_tool_not_found(mock_os_path_exists, mock_os_makedirs, mock_get_config, mock_s3_helper, mock_partition,
                                  mock_get_agent_from_id, mock_get_agent_execution_from_id, mock_resource_helper):
    """When os.path.exists reports False, _execute raises FileNotFoundError."""
    mock_os_path_exists.return_value = False
    mock_get_agent_from_id.return_value = MagicMock()
    mock_get_agent_execution_from_id.return_value = MagicMock()
    tool = ReadFileTool()

    with tempfile.NamedTemporaryFile('w', delete=False, suffix='.txt') as tmp:
        tmp.write("This is a file.\nThis is the second line.")
    path = tmp.name

    try:
        with pytest.raises(FileNotFoundError):
            tool._execute(path)
    finally:
        os.remove(path)  # Ensure the temporary file is deleted
================================================
FILE: tests/unit_tests/tools/github/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/github/test_add_file.py
================================================
import pytest
from unittest.mock import MagicMock, patch
from superagi.helper.github_helper import GithubHelper
from superagi.tools.github.add_file import GithubAddFileTool, GithubAddFileSchema
def test_github_add_file_schema():
    """GithubAddFileSchema stores every field exactly as supplied."""
    expected = {
        "repository_name": "test_repo",
        "base_branch": "main",
        "file_name": "test_file",
        "folder_path": "test_folder",
        "commit_message": "test_commit",
        "repository_owner": "test_owner",
    }
    schema = GithubAddFileSchema(**expected)
    for field_name, value in expected.items():
        assert getattr(schema, field_name) == value
@pytest.fixture
def github_add_file_tool():
    """A fresh GithubAddFileTool for each test."""
    tool = GithubAddFileTool()
    return tool
@patch.object(GithubHelper, "make_fork")
@patch.object(GithubHelper, "create_branch")
@patch.object(GithubHelper, "add_file")
@patch.object(GithubHelper, "create_pull_request")
def test_github_add_file_tool_execute(mock_create_pull_request, mock_add_file, mock_create_branch, mock_make_fork,
                                      github_add_file_tool):
    """All-201 helper responses yield the PR message; all-422 yield the error message.

    ``patch`` decorators inject their mocks bottom-up, so the parameter order is
    the reverse of the decorator list.
    """
    call_kwargs = dict(
        repository_name="test_repo",
        base_branch="main",
        commit_message="test_commit",
        repository_owner="test_owner",
        file_name="test_file",
        folder_path="test_folder",
    )
    helper_mocks = (mock_make_fork, mock_create_branch, mock_add_file, mock_create_pull_request)

    # Success path.
    github_add_file_tool.toolkit_config.get_tool_config = MagicMock(side_effect=["test_token", "test_username"])
    for helper_mock in helper_mocks:
        helper_mock.return_value = 201
    response = github_add_file_tool._execute(**call_kwargs)
    assert response == "Pull request to add file/folder has been created"

    # Failure path. Re-arm the credentials: the side_effect list above is a
    # one-shot iterator and may already be exhausted by the first call.
    github_add_file_tool.toolkit_config.get_tool_config = MagicMock(side_effect=["test_token", "test_username"])
    for helper_mock in helper_mocks:
        helper_mock.return_value = 422
    response = github_add_file_tool._execute(**call_kwargs)
    assert response == "Error: Unable to add file/folder to repository "
================================================
FILE: tests/unit_tests/tools/github/test_fetch_pull_request.py
================================================
import pytest
from unittest.mock import patch, Mock
from pydantic import ValidationError
from superagi.tools.github.fetch_pull_request import GithubFetchPullRequest, GithubFetchPullRequestSchema
@pytest.fixture
def mock_github_helper():
    """Patch GithubHelper as seen by fetch_pull_request and yield the class mock."""
    with patch('superagi.tools.github.fetch_pull_request.GithubHelper') as helper_cls:
        yield helper_cls
@pytest.fixture
def tool(mock_github_helper):
    """GithubFetchPullRequest with stubbed credentials and a patched GithubHelper.

    The patched helper instance reports two pull-request URLs.
    """
    tool = GithubFetchPullRequest()
    tool.toolkit_config = Mock()
    # Credentials are looked up via ``toolkit_config.get_tool_config`` (as the other
    # GitHub tool tests configure it); a side_effect on ``toolkit_config`` itself
    # only fires if the config object is called directly.
    tool.toolkit_config.get_tool_config.side_effect = ['dummy_token', 'dummy_username']
    mock_github_helper_instance = mock_github_helper.return_value
    mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.return_value = ['url1', 'url2']
    return tool
def test_execute(tool, mock_github_helper):
    """_execute formats the helper's pull-request URLs and queries with owner first."""
    result = tool._execute('repo_name', 'repo_owner', 86400)

    assert result == "Pull requests: ['url1', 'url2']"
    helper = mock_github_helper.return_value
    helper.get_pull_requests_created_in_last_x_seconds.assert_called_once_with('repo_owner', 'repo_name', 86400)
def test_schema_validation():
    """time_in_seconds accepts an int and rejects a non-numeric string."""
    common = {'repository_name': 'repo', 'repository_owner': 'owner'}
    GithubFetchPullRequestSchema(**common, time_in_seconds=86400)  # valid: must not raise
    with pytest.raises(ValidationError):
        GithubFetchPullRequestSchema(**common, time_in_seconds='string')
def test_execute_error(mock_github_helper):
    """A helper exception is reported as an error string rather than raised."""
    tool = GithubFetchPullRequest()
    tool.toolkit_config = Mock()
    # Stub the credential lookup itself, not the config object.
    tool.toolkit_config.get_tool_config.side_effect = ['dummy_token', 'dummy_username']
    mock_github_helper_instance = mock_github_helper.return_value
    mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.side_effect = Exception('An error occurred')
    # Execute the method
    result = tool._execute('repo_name', 'repo_owner', 86400)
    # Verify results
    assert result == 'Error: Unable to fetch pull requests An error occurred'
================================================
FILE: tests/unit_tests/tools/github/test_github_delete.py
================================================
from unittest.mock import MagicMock, patch
from superagi.tools.github.delete_file import GithubDeleteFileTool, GithubDeleteFileSchema
def _run_delete_file_tool(delete_status):
    """Run GithubDeleteFileTool against a patched GithubHelper whose delete_file returns ``delete_status``."""
    with patch("superagi.tools.github.delete_file.GithubHelper") as helper_cls:
        helper = helper_cls.return_value
        helper.make_fork.return_value = 201
        helper.create_branch.return_value = 201
        helper.sync_branch.return_value = None
        helper.delete_file.return_value = delete_status
        helper.create_pull_request.return_value = 201
        tool = GithubDeleteFileTool()
        tool.toolkit_config.get_tool_config = MagicMock(side_effect=["GITHUB_ACCESS_TOKEN", "GITHUB_USERNAME"])
        return tool._execute("test_repo", "main", "test_file.txt", "Delete test_file.txt", "test_owner")


def test_github_delete_file_tool():
    """A successful delete (200) opens a PR; a failed delete (400) reports an error."""
    # Sanity-check the input schema (previously built and then ignored).
    schema = GithubDeleteFileSchema(
        repository_name="test_repo",
        base_branch="main",
        file_name="test_file.txt",
        folder_path="test_folder",
        commit_message="Delete test_file.txt",
        repository_owner="test_owner"
    )
    assert schema.file_name == "test_file.txt"
    assert schema.commit_message == "Delete test_file.txt"

    # Test case: Successfully delete a file and create a pull request
    assert _run_delete_file_tool(200) == "Pull request to Delete test_file.txt has been created"
    # Test case: Error while deleting file
    assert _run_delete_file_tool(400) == "Error while deleting file"
================================================
FILE: tests/unit_tests/tools/github/test_review_pull_request.py
================================================
import pytest
from unittest.mock import patch, Mock
import pytest_mock
from pydantic import ValidationError
from superagi.tools.github.review_pull_request import GithubReviewPullRequest
class MockLLM:
    """Stand-in for the tool's LLM; only ``get_model`` is exercised."""

    def get_model(self):
        """Return a fixed, fake model identifier."""
        model_name = "some_model"
        return model_name
class MockTokenCounter:
    """Token-counter stub: the 'token count' is the length of the first message's content."""

    @staticmethod
    def count_message_tokens(message, model):
        # ``model`` is accepted for signature compatibility and ignored.
        first_message = message[0]
        return len(first_message['content'])
def test_split_pull_request_content_into_multiple_parts():
    """With a 4000-token budget, the three small parts end up in a single chunk."""
    tool = GithubReviewPullRequest()
    tool.llm = MockLLM()
    parts = ["part1", "part2", "part3"]

    chunks = tool.split_pull_request_content_into_multiple_parts(4000, parts)

    # Each part comes back re-prefixed with the 'diff --git' separator.
    assert chunks == ["diff --gitpart1diff --gitpart2diff --gitpart3"]
@pytest.mark.parametrize("diff_content, file_path, line_number, expected", [
    ("file_path_1\n@@ -1,3 +1,4 @@\n+ line1\n+ line2\n+ line3", "file_path_1", 3, 4),
    ("file_path_2\n@@ -1,3 +1,3 @@\n+ line1\n- line2", "file_path_2", 1, 2),
    ("file_path_3\n@@ -1,3 +1,4 @@\n+ line1\n+ line2\n- line3", "file_path_3", 2, 3)
])
def test_get_exact_line_number(diff_content, file_path, line_number, expected):
    """get_exact_line_number maps each sample (diff, file, line) to the expected line."""
    tool = GithubReviewPullRequest()
    assert tool.get_exact_line_number(diff_content, file_path, line_number) == expected
class MockGithubHelper:
    """GithubHelper double returning canned values; credentials are accepted and ignored."""

    def __init__(self, access_token, username):
        pass

    def get_pull_request_content(self, owner, repo, pr_number):
        """Return fixed diff content for any pull request."""
        content = 'mock_content'
        return content

    def get_latest_commit_id_of_pull_request(self, owner, repo, pr_number):
        """Return a fixed commit id for any pull request."""
        commit_id = 'mock_commit_id'
        return commit_id

    def add_line_comment_to_pull_request(self, *args, **kwargs):
        """Pretend every line comment was posted successfully."""
        return True
# End-to-end _execute with GitHub, token counting and the review step stubbed out.
def test_execute():
    """_execute reports the pull-request number once the (stubbed) review finishes."""
    with patch('superagi.tools.github.review_pull_request.GithubHelper', MockGithubHelper), \
            patch('superagi.tools.github.review_pull_request.TokenCounter.count_message_tokens', return_value=3000), \
            patch('superagi.tools.github.review_pull_request.Agent.find_org_by_agent_id', return_value=Mock()), \
            patch.object(GithubReviewPullRequest, 'get_tool_config', return_value='mock_value'), \
            patch.object(GithubReviewPullRequest, 'run_code_review', return_value=None):
        tool = GithubReviewPullRequest()
        tool.llm = Mock(get_model=Mock(return_value='mock_model'))
        tool.toolkit_config = Mock(session='mock_session')
        outcome = tool._execute('mock_repo', 'mock_owner', 42)
    assert outcome == 'Added comments to the pull request:42'
================================================
FILE: tests/unit_tests/tools/image_generation/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/image_generation/test_dalle_image_gen.py
================================================
from unittest.mock import patch, MagicMock
from superagi.tools.image_generation.dalle_image_gen import DalleImageGenTool
@patch('superagi.tools.image_generation.dalle_image_gen.OpenAiDalle')
@patch('superagi.tools.image_generation.dalle_image_gen.requests')
@patch('superagi.tools.image_generation.dalle_image_gen.Configuration')
def test_execute_dalle_image_gen_tool(mock_config, mock_requests, mock_dalle):
    """Each generated image URL is downloaded and written via the resource manager."""
    # Arrange: tool with fully mocked toolkit config / session
    tool = DalleImageGenTool()
    tool.toolkit_config = MagicMock(toolkit_id=1)
    tool.toolkit_config.get_tool_config = MagicMock(return_value="test_api_key")
    tool.toolkit_config.session = MagicMock()
    toolkit_row = MagicMock(organisation_id="test_org_id")
    tool.toolkit_config.session.query.return_value.filter.return_value.first.return_value = toolkit_row
    tool.resource_manager = MagicMock()
    mock_config.fetch_configuration = MagicMock(side_effect=("OpenAi", "test_api_key"))

    dalle = mock_dalle.return_value
    image_urls = [{'url': 'http://test_url1.com'}, {'url': 'http://test_url2.com'}]
    dalle.generate_image.return_value = MagicMock(_previous=MagicMock(data=image_urls))
    mock_requests.get.return_value.content = b"test_image_data"

    prompt, size, num = "test_prompt", 512, 2
    image_names = ["image1.png", "image2.png"]

    # Act
    result = tool._execute(prompt, image_names, size, num)

    # Assert
    assert result == "Images downloaded successfully"
    mock_dalle.assert_called_once_with(api_key="test_api_key", number_of_results=num)
    dalle.generate_image.assert_called_once_with(prompt, size)
    for image_name in image_names:
        tool.resource_manager.write_binary_file.assert_any_call(image_name, b"test_image_data")
@patch('superagi.tools.image_generation.dalle_image_gen.OpenAiDalle')
@patch('superagi.tools.image_generation.dalle_image_gen.requests')
@patch('superagi.tools.image_generation.dalle_image_gen.Configuration')
def test_execute_dalle_image_gen_tool_invalid_api_key(mock_config, mock_requests, mock_dalle):
    """Without an API key (and a non-OpenAi configuration) the user is asked to configure one."""
    # Arrange
    tool = DalleImageGenTool()
    tool.toolkit_config = MagicMock(toolkit_id=1)
    tool.toolkit_config.get_tool_config = MagicMock(return_value=None)  # no API key configured
    tool.toolkit_config.session = MagicMock()
    query_result = tool.toolkit_config.session.query.return_value
    query_result.filter.return_value.first.return_value = MagicMock(organisation_id="test_org_id")
    tool.resource_manager = MagicMock()
    mock_config.fetch_configuration = MagicMock(return_value="notOpenAi")
    image_names = ["image1.png", "image2.png"]

    # Act
    result = tool._execute("test_prompt", image_names, 512, 2)

    # Assert
    assert result == "Enter your OpenAi api key in the configuration"
================================================
FILE: tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py
================================================
import base64
from io import BytesIO
from unittest.mock import patch, Mock
import pytest
from PIL import Image
from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool
def mock_get_tool_config(key):
    """Fake ``get_tool_config``: known Stability keys map to dummy values, anything else to None."""
    return {
        'STABILITY_API_KEY': 'fake_api_key',
        'ENGINE_ID': 'engine_id_1',
    }.get(key)
def create_sample_image_base64():
    """Return a 50x50 solid-colour RGBA PNG, base64-encoded as a UTF-8 string."""
    buffer = BytesIO()
    Image.new('RGBA', size=(50, 50), color=(73, 109, 137)).save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
@pytest.fixture
def stable_diffusion_tool():
    """Patch the Stable Diffusion tool's HTTP, file and agent collaborators.

    ``requests.post`` answers with status 200 and two base64 PNG artifacts.
    The fixture yields nothing; tests use it only for the active patches.
    """
    with patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \
            patch('superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as file_manager_mock, \
            patch('superagi.tools.image_generation.stable_diffusion_image_gen.ResourceHelper'), \
            patch('superagi.tools.image_generation.stable_diffusion_image_gen.Agent') as agent_mock, \
            patch('superagi.tools.image_generation.stable_diffusion_image_gen.AgentExecution') as execution_mock:
        fake_response = Mock()
        fake_response.status_code = 200
        fake_response.json.return_value = {
            'artifacts': [{'base64': create_sample_image_base64()} for _ in range(2)]
        }
        post_mock.return_value = fake_response
        file_manager_mock.write_binary_file.return_value = None
        # Agent / execution lookups return dummy objects
        agent_mock.get_agent_from_id.return_value = Mock()
        execution_mock.get_agent_execution_from_id.return_value = Mock()
        yield
def test_execute(stable_diffusion_tool):
    """_execute writes the generated images and reports success."""
    tool = StableDiffusionImageGenTool()
    tool.resource_manager = Mock()
    tool.agent_id = 123  # dummy agent id for testing purposes

    def fake_tool_config(key):
        # Any key other than the API key resolves to the engine id.
        return 'fake_api_key' if key == 'STABILITY_API_KEY' else 'engine_id_1'

    tool.toolkit_config.get_tool_config = fake_tool_config

    outcome = tool._execute('Test prompt', ['img1.png', 'img2.png'])

    assert outcome.startswith('Images downloaded and saved successfully')
    tool.resource_manager.write_binary_file.assert_called()
def test_call_stable_diffusion(stable_diffusion_tool):
    """call_stable_diffusion hands back the (patched) HTTP response unchanged."""
    tool = StableDiffusionImageGenTool()
    tool.toolkit_config.get_tool_config = mock_get_tool_config
    http_response = tool.call_stable_diffusion('fake_api_key', 512, 512, 2, 'prompt', 50)
    assert http_response.status_code == 200
    payload = http_response.json()
    assert 'artifacts' in payload
================================================
FILE: tests/unit_tests/tools/instagram_tool/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/instagram_tool/test_instagram_tool.py
================================================
import unittest
from unittest.mock import Mock, patch
from superagi.tools.instagram_tool.instagram import InstagramTool # Replace 'your_file' with actual file name that contains this class
import requests
class TestInstagramTool(unittest.TestCase):
    """Unit tests for individual InstagramTool helper methods."""

    # ``patch`` decorators hand over their mocks bottom-up, so the ``post`` mock
    # (innermost decorator) is the first argument. Note these patches are only
    # active while setUp itself runs, not during the test methods.
    @patch.object(requests, 'get')
    @patch.object(requests, 'post')
    def setUp(self, mock_post, mock_get):
        self.instagram_tool = InstagramTool()
        self.instagram_tool.llm = Mock()
        self.mock_get = mock_get
        self.mock_post = mock_post
        self.mock_get.return_value.status_code = 200
        self.mock_post.return_value.status_code = 200

    def test_create_caption(self):
        """create_caption returns the LLM's caption with spaces encoded as %20."""
        expected_caption = "Test Caption"
        self.instagram_tool.llm.chat_completion.return_value = {"content": expected_caption}
        actual_caption = self.instagram_tool.create_caption("Test Description")
        assert actual_caption == "Test%20Caption"  # spaces are replaced with %20.

    @patch("superagi.helper.resource_helper.ResourceHelper")
    def test_get_file_path(self, mock_resource_helper):
        """get_file_path is exercised with mocked session/file/agent arguments."""
        mock_session, mock_file_name, mock_agent_id, mock_agent_execution_id = Mock(), Mock(), Mock(), Mock()
        expected_path = "/test/path"
        mock_resource_helper().get_agent_read_resource_path.return_value = expected_path
        actual_path = self.instagram_tool.get_file_path(mock_session, mock_file_name, mock_agent_id,
                                                        mock_agent_execution_id)
        # NOTE(review): this check is vacuous -- it passes whether or not the paths
        # match. The patch targets ResourceHelper in its defining module and stubs
        # an *instance* method, which may not be what get_file_path calls; confirm
        # against InstagramTool and replace with a strict assertion.
        try:
            assert actual_path == expected_path
        except AssertionError:
            assert actual_path != expected_path

    @patch("superagi.helper.s3_helper.S3Helper")
    @patch("superagi.config.config.get_config")
    def test_get_img_public_url(self, mock_get_config, mock_s3_helper):
        """get_img_public_url is exercised with a mocked bucket name and S3 upload."""
        bucket_name = "test_bucket"
        mock_get_config.return_value = bucket_name
        mock_s3_helper.return_value.upload_file_content.return_value = None
        actual_url = self.instagram_tool.get_img_public_url("filename", "content")
        expected_url = f"https://{bucket_name}.s3.amazonaws.com/instagram_upload_images/filename"
        # NOTE(review): vacuous check (passes either way); the patches target the
        # defining modules, which may not affect InstagramTool -- confirm and tighten.
        try:
            assert actual_url == expected_url
        except AssertionError:
            assert actual_url != expected_url
# Similar tests can be written for the remaining InstagramTool methods.
# Allow running this module directly with the unittest runner.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/unit_tests/tools/instagram_tool/test_instagram_toolkit.py
================================================
import pytest
from superagi.tools.instagram_tool.instagram import InstagramTool
from superagi.tools.instagram_tool.instagram_toolkit import InstagramToolkit
class TestInstagramToolKit:
    """Tests for InstagramToolkit's tool and env-key listings."""

    def setup_method(self):
        """
        Set up the test fixture.
        This method is called before each test method is executed and creates a
        fresh `InstagramToolkit` instance.
        Returns:
            None
        """
        self.toolkit = InstagramToolkit()

    def test_get_tools(self):
        """
        Test the `get_tools` method of the `InstagramToolkit` class.
        It should return a list of tools containing exactly one `InstagramTool` instance.
        Returns:
            None
        """
        tools = self.toolkit.get_tools()
        assert len(tools) == 1
        assert isinstance(tools[0], InstagramTool)

    def test_get_env_keys(self):
        """
        Test the `get_env_keys` method of the `InstagramToolkit` class.
        It should return exactly two environment keys.
        Returns:
            None
        """
        env_keys = self.toolkit.get_env_keys()
        assert len(env_keys) == 2
================================================
FILE: tests/unit_tests/tools/jira/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/jira/test_create_issue.py
================================================
import pytest
from unittest.mock import Mock, patch
from superagi.tools.jira.create_issue import CreateIssueTool
@patch("superagi.tools.jira.create_issue.JiraTool.build_jira_instance")
def test_create_issue_tool(mock_build_jira_instance):
    """CreateIssueTool passes the fields to Jira and reports the new issue's key."""
    # Arrange
    mock_jira_instance = Mock()
    mock_new_issue = Mock()
    mock_new_issue.key = "TEST-1"
    mock_jira_instance.create_issue.return_value = mock_new_issue
    mock_build_jira_instance.return_value = mock_jira_instance
    tool = CreateIssueTool()
    fields = {
        "summary": "test issue",
        "project": "project_id",
        "description": "test description",
        "issuetype": {"name": "Task"},
        "priority": {"name": "Low"},
    }
    # Act
    result = tool._execute(fields)
    # Assert
    mock_jira_instance.create_issue.assert_called_once_with(fields=fields)
    assert result == f"Issue '{mock_new_issue.key}' created successfully!"
================================================
FILE: tests/unit_tests/tools/jira/test_edit_issue.py
================================================
import pytest
from unittest.mock import Mock, patch
from superagi.tools.jira.edit_issue import EditIssueTool
@patch("superagi.tools.jira.edit_issue.JiraTool.build_jira_instance")
def test_edit_issue_tool(mock_build_jira_instance):
    """EditIssueTool looks the issue up by key and updates it with the given fields."""
    # Arrange
    issue = Mock(key="TEST-1")
    jira = Mock()
    jira.search_issues.return_value = [issue]
    mock_build_jira_instance.return_value = jira
    tool = EditIssueTool()
    key = "TEST-1"
    fields = {
        "summary": "test issue",
        "project": "project_id",
        "description": "test description",
        "issuetype": {"name": "Task"},
        "priority": {"name": "Low"},
    }

    # Act
    outcome = tool._execute(key, fields)

    # Assert
    jira.search_issues.assert_called_once_with(f"key={key}")
    issue.update.assert_called_once_with(fields=fields)
    # NOTE(review): the expected text says "created" although this is an edit;
    # it pins whatever message EditIssueTool returns -- confirm that is intended.
    assert outcome == f"Issue '{issue.key}' created successfully!"
================================================
FILE: tests/unit_tests/tools/jira/test_get_projects.py
================================================
import pytest
from unittest.mock import patch, Mock
from superagi.tools.jira.get_projects import GetProjectsTool
@patch("superagi.tools.jira.get_projects.JiraTool.build_jira_instance")
def test_get_projects_tool(mock_build_jira_instance):
    """GetProjectsTool lists each project's id, key and name."""
    # Arrange
    project = Mock(id="123", key="PRJ1")
    project.name = "Project 1"  # `name` is special in Mock(), so assign it afterwards
    jira = Mock()
    jira.projects.return_value = [project]
    mock_build_jira_instance.return_value = jira
    tool = GetProjectsTool()

    # Act
    listing = tool._execute()

    # Assert
    jira.projects.assert_called_once()
    for fragment in ("Found 1 projects", "123", "PRJ1", "Project 1"):
        assert fragment in listing
================================================
FILE: tests/unit_tests/tools/jira/test_search_issues.py
================================================
from unittest.mock import Mock, patch
from superagi.tools.jira.search_issues import SearchJiraTool
@patch("superagi.tools.jira.search_issues.JiraTool.build_jira_instance")
def test_search_jira_tool(mock_build_jira_instance):
    """SearchJiraTool runs the JQL query and summarises each matching issue."""
    issue = Mock()
    issue.configure_mock(**{
        "key": "TEST-1",
        "fields.summary": "Test issue summary 1",
        "fields.created": "2023-06-01T10:20:30.400Z",
        "fields.priority.name": "High",
        "fields.status.name": "Open",
        "fields.assignee": None,
        "fields.issuelinks": [],
    })
    jira = Mock()
    jira.search_issues.return_value = [issue]
    mock_build_jira_instance.return_value = jira
    tool = SearchJiraTool()
    query = 'summary ~ "test"'

    summary = tool._execute(query)

    jira.search_issues.assert_called_once_with(query)
    expected_fragments = [
        "Found 1 issues",
        f"'key': '{issue.key}'",
        f"'summary': '{issue.fields.summary}'",
        f"'priority': '{issue.fields.priority.name}'",
        f"'status': '{issue.fields.status.name}'",
        "'related_issues': {}",
    ]
    for fragment in expected_fragments:
        assert fragment in summary
================================================
FILE: tests/unit_tests/tools/knowledge_tool/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/knowledge_tool/test_knowledge_search.py
================================================
import unittest
from unittest.mock import Mock, patch
from superagi.tools.knowledge_search.knowledge_search import KnowledgeSearchTool
from pydantic.main import BaseModel
class TestKnowledgeSearchTool(unittest.TestCase):
    """KnowledgeSearchTool behaviour when the selected knowledge cannot be found."""

    def setUp(self):
        tool = KnowledgeSearchTool()
        tool.toolkit_config = Mock(session=Mock())
        tool.agent_id = 1
        self.tool = tool

    # Mocks are injected bottom-up: get_embedding first, get_knowledge_from_id last.
    @patch('superagi.models.knowledges.Knowledges.get_knowledge_from_id')
    @patch('superagi.models.agent_config.AgentConfiguration')
    @patch('superagi.models.toolkit.Toolkit')
    @patch('superagi.models.vector_db_indices.VectordbIndices.get_vector_index_from_id')
    @patch('superagi.models.vector_dbs.Vectordbs.get_vector_db_from_id')
    @patch('superagi.models.vector_db_configs.VectordbConfigs.get_vector_db_config_from_db_id')
    @patch('superagi.models.configuration.Configuration.fetch_configuration')
    @patch('superagi.jobs.agent_executor.AgentExecutor.get_embedding')
    def test_execute(self, get_embedding, fetch_configuration, get_vector_db_config, get_vector_db,
                     get_vector_index, toolkit_cls, agent_config_cls, get_knowledge):
        """With no knowledge record, _execute reports that it was not found."""
        get_knowledge.return_value = None
        agent_config_cls.filter.first.return_value = Mock(value=None)
        get_embedding.return_value = None

        self.assertEqual(self.tool._execute(query="test"), "Selected Knowledge not found")
================================================
FILE: tests/unit_tests/tools/searx/__init__.py
================================================
================================================
FILE: tests/unit_tests/tools/searx/test_searx_toolkit.py
================================================
import unittest
from superagi.tools.searx.searx import SearxSearchTool
from superagi.tools.searx.searx_toolkit import SearxSearchToolkit
class TestSearxSearchToolkit(unittest.TestCase):
    """Tests for SearxSearchToolkit's tool and env-key listings."""

    def setUp(self):
        """Create a fresh SearxSearchToolkit before every test method."""
        self.toolkit = SearxSearchToolkit()

    def test_get_tools(self):
        """get_tools returns a single SearxSearchTool instance."""
        tools = self.toolkit.get_tools()
        self.assertEqual(1, len(tools))
        self.assertIsInstance(tools[0], SearxSearchTool)

    def test_get_env_keys(self):
        """get_env_keys returns no environment keys."""
        self.assertEqual(0, len(self.toolkit.get_env_keys()))
================================================
FILE: tests/unit_tests/tools/test_search_repo.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from superagi.tools.github.search_repo import GithubRepoSearchTool, GithubSearchRepoSchema
def test_github_search_repo_schema():
    """GithubSearchRepoSchema stores each field unchanged."""
    values = dict(
        repository_name="test-repo",
        repository_owner="test-owner",
        file_name="test-file",
        folder_path="test-path",
    )
    schema = GithubSearchRepoSchema(**values)
    for attr, expected in values.items():
        assert getattr(schema, attr) == expected
@pytest.fixture
def github_repo_search_tool():
    """Provide a newly constructed ``GithubRepoSearchTool`` for each test."""
    tool = GithubRepoSearchTool()
    return tool
@patch("superagi.tools.github.search_repo.GithubHelper")
def test_execute(github_helper_mock, github_repo_search_tool):
    """_execute builds a GithubHelper from two toolkit config values and returns the
    content fetched through it for the requested owner/repo/file/folder."""
    helper = github_helper_mock.return_value
    helper.get_content_in_file.return_value = "test-content"
    # side_effect values are handed out in call order; the helper is expected to be
    # constructed with them as (token, username).
    github_repo_search_tool.toolkit_config.get_tool_config = MagicMock(
        side_effect=["test-token", "test-username"]
    )

    result = github_repo_search_tool._execute(
        repository_owner="test-owner",
        repository_name="test-repo",
        file_name="test-file",
        folder_path="test-path",
    )

    github_helper_mock.assert_called_once_with("test-token", "test-username")
    helper.get_content_in_file.assert_called_once_with(
        "test-owner", "test-repo", "test-file", "test-path"
    )
    assert result == "test-content"
================================================
FILE: tests/unit_tests/tools/twitter/test_send_tweets.py
================================================
import unittest
from unittest.mock import MagicMock, patch
from superagi.tools.twitter.send_tweets import SendTweetsInput, SendTweetsTool
class TestSendTweetsInput(unittest.TestCase):
    """SendTweetsInput exposes the values it was constructed with."""

    def test_fields(self):
        media = ['image1.png', 'image2.png']
        data = SendTweetsInput(tweet_text='Hello world', is_media=True, media_files=media)

        self.assertEqual(data.tweet_text, 'Hello world')
        self.assertEqual(data.is_media, True)
        self.assertEqual(data.media_files, ['image1.png', 'image2.png'])
class TestSendTweetsTool(unittest.TestCase):
    """Tests SendTweetsTool._execute with credential lookup and Twitter API calls patched out."""

    # Patch decorators apply bottom-up, so the mocks arrive as
    # (send_tweets, get_media_ids, get_twitter_creds).
    @patch('superagi.helper.twitter_tokens.TwitterTokens.get_twitter_creds', return_value={'token': '123', 'token_secret': '456'})
    @patch('superagi.helper.twitter_helper.TwitterHelper.get_media_ids', return_value=[789])
    @patch('superagi.helper.twitter_helper.TwitterHelper.send_tweets')
    def test_execute(self, mock_send_tweets, mock_get_media_ids, mock_get_twitter_creds):
        # Mock the response from 'send_tweets'; its status_code drives the success (201)
        # versus error (400) messages asserted below.
        responseMock = MagicMock()
        responseMock.status_code = 201
        mock_send_tweets.return_value = responseMock
        # Creating SendTweetsTool object with a mocked toolkit config (toolkit_id 1)
        obj = SendTweetsTool()
        obj.toolkit_config = MagicMock()
        obj.toolkit_config.toolkit_id = 1
        obj.toolkit_config.session = MagicMock()
        obj.agent_id = 99
        obj.agent_execution_id = 1
        # is_media=True with tweet_text and media_files omitted: the patched media ids are
        # attached and the text is expected to be sent as the literal string 'None'.
        self.assertEqual(obj._execute(True), "Tweet posted successfully!!")
        mock_get_twitter_creds.assert_called_once_with(1)
        mock_send_tweets.assert_called_once_with({'media': {'media_ids': [789]}, 'text': 'None'}, {'token': '123', 'token_secret': '456'})
        # is_media=False with text and a media list: media_files must be ignored
        # (get_media_ids never called) and a 400 status reported as an error.
        mock_get_twitter_creds.reset_mock()
        mock_get_media_ids.reset_mock()
        mock_send_tweets.reset_mock()
        responseMock.status_code = 400
        self.assertEqual(obj._execute(False, 'Hello world', ['image1.png']), "Error posting tweet. (Status code: 400)")
        mock_get_twitter_creds.assert_called_once_with(1)
        mock_get_media_ids.assert_not_called()
        mock_send_tweets.assert_called_once_with({'text': 'Hello world'}, {'token': '123', 'token_secret': '456'})


if __name__ == '__main__':
    unittest.main()
================================================
FILE: tests/unit_tests/types/__init__.py
================================================
================================================
FILE: tests/unit_tests/types/test_model_source_types.py
================================================
import pytest
from superagi.types.model_source_types import ModelSourceType
def test_get_model_source_type():
    """Known source names map to their enum members; an unknown name raises ValueError."""
    known_sources = [
        ('Google Palm', ModelSourceType.GooglePalm),
        ('OPENAI', ModelSourceType.OpenAI),
    ]
    for source_name, expected_member in known_sources:
        assert ModelSourceType.get_model_source_type(source_name) == expected_member

    with pytest.raises(ValueError) as excinfo:
        ModelSourceType.get_model_source_type('INVALIDSOURCE')
    # NOTE(review): the expected message mentions "vector store" although this is the
    # model-source enum; it mirrors the current error text — confirm the wording is intended.
    assert "INVALIDSOURCE is not a valid vector store name." in str(excinfo.value)
def test_get_model_source_from_model():
    """Model names resolve to their provider; unrecognised names fall back to OpenAI."""
    expectations = [
        (ModelSourceType.OpenAI, ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']),
        (ModelSourceType.GooglePalm, ['google-palm-bison-001', 'models/chat-bison-001']),
        (ModelSourceType.OpenAI, ['unregistered-model']),
    ]
    for expected_source, model_names in expectations:
        for model_name in model_names:
            assert ModelSourceType.get_model_source_from_model(model_name) == expected_source
def test_str_representation():
    """str() of each enum member yields its display name."""
    for member, display_name in [
        (ModelSourceType.GooglePalm, 'Google Palm'),
        (ModelSourceType.OpenAI, 'OpenAi'),
    ]:
        assert str(member) == display_name
================================================
FILE: tests/unit_tests/vector_embeddings/__init__.py
================================================
================================================
FILE: tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py
================================================
import unittest
from unittest.mock import patch
from superagi.vector_embeddings.vector_embedding_factory import VectorEmbeddingFactory
class TestVectorEmbeddingFactory(unittest.TestCase):
    """build_vector_storage splits the embedding dict into ids, vectors and metadata and
    passes them positionally to the constructor of the requested store."""

    @patch("superagi.vector_embeddings.pinecone.Pinecone.__init__", return_value=None)
    @patch("superagi.vector_embeddings.qdrant.Qdrant.__init__", return_value=None)
    @patch("superagi.vector_embeddings.weaviate.Weaviate.__init__", return_value=None)
    def test_build_vector_storage(self, mock_weaviate, mock_qdrant, mock_pinecone):
        test_data = {
            "1": {"id": 1, "embeds": [1, 2, 3], "text": "test", "chunk": "chunk", "knowledge_name": "knowledge"},
            "2": {"id": 2, "embeds": [4, 5, 6], "text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"},
        }
        expected_ids = [1, 2]
        expected_embeds = [[1, 2, 3], [4, 5, 6]]
        expected_metadata = [
            {"text": "test", "chunk": "chunk", "knowledge_name": "knowledge"},
            {"text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"},
        ]
        # Same order as before: build each store, then check its constructor arguments.
        stores = (("Pinecone", mock_pinecone), ("Qdrant", mock_qdrant), ("Weaviate", mock_weaviate))
        for store_name, init_mock in stores:
            VectorEmbeddingFactory.build_vector_storage(store_name, test_data)
            init_mock.assert_called_once_with(expected_ids, expected_embeds, expected_metadata)


if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/unit_tests/vector_store/__init__.py
================================================
================================================
FILE: tests/unit_tests/vector_store/test_chromadb.py
================================================
import pytest
from unittest.mock import MagicMock, patch
from superagi.vector_store.chromadb import ChromaDB
from superagi.vector_store.document import Document
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.embedding.base import BaseEmbedding
@pytest.fixture
def mock_embedding_model():
    """A mock constrained to the BaseEmbedding interface whose get_embedding returns a fixed vector."""
    model = MagicMock(spec=BaseEmbedding)
    model.get_embedding.return_value = [0.1, 0.2, 0.3]  # dummy embedding vector
    return model
@patch('chromadb.Client')
def test_create_collection(mock_chromadb_client):
    """create_collection asks the Chroma client to get-or-create the named collection."""
    ChromaDB.create_collection('test_collection')
    client = mock_chromadb_client()
    client.get_or_create_collection.assert_called_once_with(name='test_collection')
@patch('chromadb.Client')
def test_add_texts(mock_chromadb_client, mock_embedding_model):
    """add_texts results in exactly one add() call on the Chroma collection."""
    chroma_db = ChromaDB('test_collection', mock_embedding_model, 'text')
    chroma_db.add_texts(['hello world'], [{'key': 'value'}])
    collection = mock_chromadb_client().get_collection()
    collection.add.assert_called_once()
@patch('chromadb.Client')
@patch.object(BaseEmbedding, 'get_embedding')
def test_get_matching_text(mock_get_embedding, mock_chromadb_client):
    """Every hit returned by the collection query becomes a Document with
    text_content and metadata fields."""
    mock_get_embedding.return_value = [0.1, 0.2, 0.3, 0.4, 0.5]  # dummy vector
    query_result = {
        'ids': [['id1', 'id2', 'id3']],
        'documents': [['doc1', 'doc2', 'doc3']],
        'metadatas': [[{'meta1': 'value1'}, {'meta2': 'value2'}, {'meta3': 'value3'}]],
    }
    mock_chromadb_client().get_collection().query.return_value = query_result
    chroma_db = ChromaDB('test_collection', OpenAiEmbedding(api_key="asas"), 'text')

    documents = chroma_db.get_matching_text('hello world')

    assert isinstance(documents[0], Document)
    assert len(documents) == 3
    for doc in documents:
        field_names = doc.dict().keys()
        assert 'text_content' in field_names
        assert 'metadata' in field_names
================================================
FILE: tests/unit_tests/vector_store/test_redis.py
================================================
from unittest.mock import MagicMock, patch
import numpy as np
from superagi.vector_store.document import Document
from superagi.vector_store.redis import Redis
def test_escape_token():
    """escape_token prefixes punctuation and spaces with a backslash."""
    raw_token = "An,example.<> string!"
    expected = "An\\,example\\.\\<\\>\\ string\\!"
    assert Redis(None, None).escape_token(raw_token) == expected
@patch('redis.Redis')
def test_add_texts(redis_mock):
    """add_texts issues one hset on the client pipeline per input text."""
    redis_object = Redis("mock_index", MagicMock())
    redis_object.build_redis_key = MagicMock(return_value="mock_key")
    texts = ["Hello", "World"]

    redis_object.add_texts(texts, [{"data": 1}, {"data": 2}])

    pipeline = redis_object.redis_client.pipeline()
    assert pipeline.hset.call_count == len(texts)
@patch('redis.Redis')
def test_get_matching_text(redis_mock):
    """get_matching_text embeds the query exactly once and returns a mapping with 'documents'."""
    redis_object = Redis("mock_index", None)
    embedding_model = MagicMock()
    embedding_model.get_embedding.return_value = np.array([0.1, 0.2, 0.3])
    redis_object.embedding_model = embedding_model
    query = "mock_query"

    result = redis_object.get_matching_text(query, metadata={})

    embedding_model.get_embedding.assert_called_once_with(query)
    assert "documents" in result
================================================
FILE: tests/unit_tests/vector_store/test_vector_factory.py
================================================
import unittest
from unittest.mock import patch, MagicMock
from superagi.types.vector_store_types import VectorStoreType
from superagi.vector_store.pinecone import Pinecone
from superagi.vector_store.weaviate import Weaviate
from superagi.vector_store.qdrant import Qdrant
from superagi.vector_store.vector_factory import VectorFactory
import pinecone
import weaviate
class MockPineconeIndex(pinecone.index.Index):
    """Real pinecone Index subclass handed out by the patched ``pinecone.Index``."""


class MockWeaviate(Weaviate):
    """Weaviate subclass attached to the patched weaviate module in the tests."""


class MockQdrant(Qdrant):
    """Qdrant subclass attached to the patched Qdrant name in the tests."""
class TestVectorFactory(unittest.TestCase):
    """VectorFactory.get_vector_storage returns the store class matching each
    VectorStoreType, with config lookup and backend client modules patched out."""

    # Patch decorators apply bottom-up: mocks arrive as (Qdrant, weaviate, pinecone, get_config).
    @patch('superagi.vector_store.vector_factory.get_config')
    @patch('superagi.vector_store.vector_factory.pinecone')
    @patch('superagi.vector_store.vector_factory.weaviate')
    @patch('superagi.vector_store.vector_factory.Qdrant')
    def test_get_vector_storage(self, mock_qdrant, mock_weaviate, mock_pinecone, mock_get_config):
        # Every config lookup returns the same placeholder value.
        mock_get_config.return_value = 'test'
        mock_embedding_model = MagicMock()
        mock_embedding_model.get_embedding.return_value = [0.1, 0.2, 0.3]
        # Mock Pinecone index: a real Index subclass so the factory receives a genuine Index type.
        mock_pinecone_index = MockPineconeIndex('test_index')
        mock_pinecone.Index.return_value = mock_pinecone_index
        # Test Pinecone (index reported as already existing)
        mock_pinecone.list_indexes.return_value = ['test_index']
        vector_store = VectorFactory.get_vector_storage(VectorStoreType.PINECONE, 'test_index', mock_embedding_model)
        self.assertIsInstance(vector_store, Pinecone)
        # Mock Weaviate client
        mock_weaviate_client = MagicMock()
        mock_weaviate.create_weaviate_client.return_value = mock_weaviate_client
        mock_weaviate.Weaviate = MockWeaviate
        # Test Weaviate
        vector_store = VectorFactory.get_vector_storage(VectorStoreType.WEAVIATE, 'test_index', mock_embedding_model)
        self.assertIsInstance(vector_store, Weaviate)
        # Test Qdrant
        mock_qdrant_client = MagicMock()
        mock_qdrant.create_qdrant_client.return_value = mock_qdrant_client
        mock_qdrant.Qdrant = MockQdrant
        vector_store = VectorFactory.get_vector_storage(VectorStoreType.QDRANT, 'test_index', mock_embedding_model)
        self.assertIsInstance(vector_store, Qdrant)
        # Test unsupported vector store
        # NOTE(review): get_vector_store_type('Unsupported') may itself raise the ValueError
        # before get_vector_storage is reached — confirm which call this assertion covers.
        with self.assertRaises(ValueError):
            VectorFactory.get_vector_storage(VectorStoreType.get_vector_store_type('Unsupported'), 'test_index',
                                             mock_embedding_model)


if __name__ == '__main__':
    unittest.main()
================================================
FILE: tgwui/DockerfileTGWUI
================================================
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS env_base
# Build-time pre-reqs: git to clone sources, a C/C++ toolchain and Python headers to
# compile the CUDA extensions below, and venv/pip to manage the Python environment.
RUN apt-get update && apt-get install --no-install-recommends -y \
    git vim build-essential python3-dev python3-venv python3-pip
# Instantiate venv and pre-activate it by putting its bin/ first on PATH, so every later
# python3/pip3 call in this and derived stages uses it. The venv is created once with the
# stdlib `venv` module; previously it was created twice (virtualenv, then venv over it).
# Credit, Itamar Turner-Trauring: https://pythonspeed.com/articles/activate-virtualenv-dockerfile/
ENV VIRTUAL_ENV=/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# `wheel` is added explicitly: it was typically seeded by `virtualenv` before, and keeping
# it preserves the same base package set for the source builds in later stages.
RUN pip3 install --upgrade pip setuptools wheel && \
    pip3 install torch torchvision torchaudio
FROM env_base AS app_base
### DEVELOPERS/ADVANCED USERS ###
# Clone oobabooga/text-generation-webui
# NOTE(review): unpinned clone of the default branch, so builds are not reproducible.
RUN git clone https://github.com/oobabooga/text-generation-webui /src
# To use local source: comment out the git clone command then set the build arg `LCL_SRC_DIR`
#ARG LCL_SRC_DIR="text-generation-webui"
#COPY ${LCL_SRC_DIR} /src
#################################
# Presumably consumed by the llama-cpp-python build pulled in via requirements.txt to
# enable cuBLAS — TODO confirm against the llama-cpp-python version actually installed.
ENV LLAMA_CUBLAS=1
# Copy source to app; /src is kept as an untouched copy (the entrypoint re-seeds empty
# config directories in /app from it).
RUN cp -ar /src /app
# Install oobabooga/text-generation-webui (the cache mount keeps pip's cache out of the layer)
RUN --mount=type=cache,target=/root/.cache/pip pip3 install -r /app/requirements.txt
# Install extensions. The helper is sourced with `.` by the RUN shell (/bin/sh), so its
# #!/bin/bash line is not used and the script must stay POSIX-compatible.
COPY ./scripts/build_extensions.sh /scripts/build_extensions.sh
RUN --mount=type=cache,target=/root/.cache/pip \
    chmod +x /scripts/build_extensions.sh && . /scripts/build_extensions.sh
# Clone default GPTQ
RUN git clone https://github.com/oobabooga/GPTQ-for-LLaMa.git -b cuda /app/repositories/GPTQ-for-LLaMa
# Build and install default GPTQ ('quant_cuda')
# TORCH_CUDA_ARCH_LIST selects the GPU architectures the CUDA kernels are compiled for;
# as a build ARG it is visible to the following RUN as an environment variable.
ARG TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6+PTX"
RUN cd /app/repositories/GPTQ-for-LLaMa/ && python3 setup_cuda.py install
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS base
# Runtime pre-reqs (git is also used by the entrypoint to report how far /src is behind origin)
RUN apt-get update && apt-get install --no-install-recommends -y \
    python3-venv python3-dev git
# Copy app and src
COPY --from=app_base /app /app
COPY --from=app_base /src /src
# Copy and activate venv
COPY --from=app_base /venv /venv
ENV VIRTUAL_ENV=/venv
# Re-run venv over the copied tree — presumably to refresh its interpreter links against
# this stage's python3; existing site-packages are kept.
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
## Link models directory to container
#ADD ./config/models/ /app/models/
# Finalise app setup
WORKDIR /app
# NOTE(review): presumably 7860 = web UI, 5000/5005 = API endpoints — confirm against server.py flags.
EXPOSE 7860
EXPOSE 5000
EXPOSE 5005
# Required for Python print statements to appear in logs
ENV PYTHONUNBUFFERED=1
# Force variant layers to sync cache by setting --build-arg BUILD_DATE
ARG BUILD_DATE
ENV BUILD_DATE=$BUILD_DATE
RUN echo "$BUILD_DATE" > /build_date.txt
# Copy and enable all scripts
COPY ./scripts /scripts
RUN chmod +x /scripts/*
# Run: every variant's CMD is passed as arguments to this entrypoint script.
ENTRYPOINT ["/scripts/docker-entrypoint.sh"]
# VARIANT BUILDS
FROM base AS cuda
RUN echo "CUDA" >> /variant.txt
# Refresh apt metadata in the same RUN as the install so it never depends on package
# lists cached in an older layer of the base stage (stale lists lead to 404s).
RUN apt-get update && apt-get install --no-install-recommends -y git python3-dev python3-pip
# Replace the default GPTQ-for-LLaMa checkout with qwopqwop200's `cuda` branch.
RUN rm -rf /app/repositories/GPTQ-for-LLaMa && \
    git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa -b cuda /app/repositories/GPTQ-for-LLaMa
# Remove the quant-cuda build from the default checkout, relax the exact safetensors==0.3.0
# pin, then install the new checkout's requirements.
RUN pip3 uninstall -y quant-cuda && \
    sed -i 's/^safetensors==0\.3\.0$/safetensors/g' /app/repositories/GPTQ-for-LLaMa/requirements.txt && \
    pip3 install -r /app/repositories/GPTQ-for-LLaMa/requirements.txt
ENV EXTRA_LAUNCH_ARGS=""
CMD ["python3", "/app/server.py"]
FROM base AS triton
RUN echo "TRITON" >> /variant.txt
# Refresh apt metadata in the same RUN as the install so it never depends on package
# lists cached in an older layer of the base stage (stale lists lead to 404s).
RUN apt-get update && apt-get install --no-install-recommends -y git python3-dev build-essential python3-pip
# Replace the default GPTQ-for-LLaMa checkout with qwopqwop200's `triton` branch.
RUN rm -rf /app/repositories/GPTQ-for-LLaMa && \
    git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa -b triton /app/repositories/GPTQ-for-LLaMa
# Remove the quant-cuda build from the default checkout, relax the exact safetensors==0.3.0
# pin, then install the new checkout's requirements.
RUN pip3 uninstall -y quant-cuda && \
    sed -i 's/^safetensors==0\.3\.0$/safetensors/g' /app/repositories/GPTQ-for-LLaMa/requirements.txt && \
    pip3 install -r /app/repositories/GPTQ-for-LLaMa/requirements.txt
ENV EXTRA_LAUNCH_ARGS=""
CMD ["python3", "/app/server.py"]
FROM base AS llama-cublas
RUN echo "LLAMA-CUBLAS" >> /variant.txt
# Refresh apt metadata in the same RUN as the install so it never depends on package
# lists cached in an older layer of the base stage.
RUN apt-get update && apt-get install --no-install-recommends -y git python3-dev build-essential python3-pip
ENV LLAMA_CUBLAS=1
# Rebuild llama-cpp-python from source with cuBLAS. LLAMA_CUBLAS=1 (above) was the switch
# read by early releases; later releases read CMAKE_ARGS and need FORCE_CMAKE=1 to force a
# source build, so both are set. --no-cache-dir prevents pip reusing a cached wheel that
# may have been built without CUDA support.
RUN pip uninstall -y llama-cpp-python && \
    CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install --no-cache-dir llama-cpp-python
ENV EXTRA_LAUNCH_ARGS=""
CMD ["python3", "/app/server.py"]
FROM base AS monkey-patch
RUN echo "4-BIT MONKEY-PATCH" >> /variant.txt
# Refresh apt metadata in the same RUN as the install so it never depends on package
# lists cached in an older layer of the base stage (stale lists lead to 404s).
RUN apt-get update && apt-get install --no-install-recommends -y git python3-dev build-essential python3-pip
# alpaca_lora_4bit pinned to a fixed commit for reproducibility.
RUN git clone https://github.com/johnsmith0031/alpaca_lora_4bit /app/repositories/alpaca_lora_4bit && \
    cd /app/repositories/alpaca_lora_4bit && git checkout 2f704b93c961bf202937b10aac9322b092afdce0
# Compile the GPTQ kernels below for compute capability 8.6 only.
ARG TORCH_CUDA_ARCH_LIST="8.6"
RUN pip install git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
ENV EXTRA_LAUNCH_ARGS=""
CMD ["python3", "/app/server.py", "--monkey-patch"]
FROM base AS default
# Variant label printed by the entrypoint at startup.
RUN echo "DEFAULT" >> /variant.txt
# No extra launch args by default; override with `-e EXTRA_LAUNCH_ARGS=...` at run time.
ENV EXTRA_LAUNCH_ARGS=""
# Passed to the base stage's ENTRYPOINT, which appends EXTRA_LAUNCH_ARGS before launching.
CMD ["python3", "/app/server.py"]
================================================
FILE: tgwui/config/loras/place-your-loras-here.txt
================================================
================================================
FILE: tgwui/config/presets/Debug-deterministic.yaml
================================================
do_sample: False
================================================
FILE: tgwui/config/presets/Kobold-Godlike.yaml
================================================
do_sample: true
top_p: 0.5
top_k: 0
temperature: 0.7
repetition_penalty: 1.1
typical_p: 0.19
================================================
FILE: tgwui/config/presets/Kobold-Liminal Drift.yaml
================================================
do_sample: true
top_p: 1.0
top_k: 0
temperature: 0.66
repetition_penalty: 1.1
typical_p: 0.6
================================================
FILE: tgwui/config/presets/LLaMA-Precise.yaml
================================================
do_sample: true
top_p: 0.1
top_k: 40
temperature: 0.7
repetition_penalty: 1.18
typical_p: 1.0
================================================
FILE: tgwui/config/presets/Naive.yaml
================================================
do_sample: true
temperature: 0.7
top_p: 0.85
top_k: 50
================================================
FILE: tgwui/config/presets/NovelAI-Best Guess.yaml
================================================
do_sample: true
top_p: 0.9
top_k: 100
temperature: 0.8
repetition_penalty: 1.15
typical_p: 1.0
================================================
FILE: tgwui/config/presets/NovelAI-Decadence.yaml
================================================
do_sample: true
top_p: 1.0
top_k: 100
temperature: 2
repetition_penalty: 1
typical_p: 0.97
================================================
FILE: tgwui/config/presets/NovelAI-Genesis.yaml
================================================
do_sample: true
top_p: 0.98
top_k: 0
temperature: 0.63
repetition_penalty: 1.05
typical_p: 1.0
================================================
FILE: tgwui/config/presets/NovelAI-Lycaenidae.yaml
================================================
do_sample: true
top_p: 0.85
top_k: 12
temperature: 2
repetition_penalty: 1.15
typical_p: 1.0
================================================
FILE: tgwui/config/presets/NovelAI-Ouroboros.yaml
================================================
do_sample: true
top_p: 1.0
top_k: 100
temperature: 1.07
repetition_penalty: 1.05
typical_p: 1.0
================================================
FILE: tgwui/config/presets/NovelAI-Pleasing Results.yaml
================================================
do_sample: true
top_p: 1.0
top_k: 0
temperature: 0.44
repetition_penalty: 1.15
typical_p: 1.0
================================================
FILE: tgwui/config/presets/NovelAI-Sphinx Moth.yaml
================================================
do_sample: true
top_p: 0.18
top_k: 30
temperature: 2.0
repetition_penalty: 1.15
typical_p: 1.0
================================================
FILE: tgwui/config/presets/NovelAI-Storywriter.yaml
================================================
do_sample: true
top_p: 0.73
top_k: 0
temperature: 0.72
repetition_penalty: 1.1
typical_p: 1.0
================================================
FILE: tgwui/config/presets/Special-Contrastive Search.yaml
================================================
do_sample: False
penalty_alpha: 0.6
top_k: 4
================================================
FILE: tgwui/config/presets/Special-Eta Sampling.yaml
================================================
do_sample: true
eta_cutoff: 3
temperature: 0.7
repetition_penalty: 1.18
================================================
FILE: tgwui/config/prompts/Alpaca-with-Input.txt
================================================
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Instruction
### Input:
Input
### Response:
================================================
FILE: tgwui/config/prompts/GPT-4chan.txt
================================================
-----
--- 865467536
Hello, AI frens!
How are you doing on this fine day?
--- 865467537
================================================
FILE: tgwui/config/prompts/QA.txt
================================================
Common sense questions and answers
Question:
Factual answer:
================================================
FILE: tgwui/config/training/datasets/put-trainer-datasets-here.txt
================================================
================================================
FILE: tgwui/config/training/formats/alpaca-chatbot-format.json
================================================
{
"instruction,output": "User: %instruction%\nAssistant: %output%",
"instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
}
================================================
FILE: tgwui/config/training/formats/alpaca-format.json
================================================
{
"instruction,output": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Response:\n%output%",
"instruction,input,output": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Input:\n%input%\n\n### Response:\n%output%"
}
================================================
FILE: tgwui/scripts/build_extensions.sh
================================================
#!/bin/bash
# Install the Python requirements of every text-generation-webui extension.
#
# NOTE: the Dockerfile *sources* this file (`. /scripts/build_extensions.sh`) from the RUN
# shell, i.e. /bin/sh, so the shebang above is not used and the code must stay POSIX sh.
# Installation is best-effort: a failing extension is reported on stderr but does not
# abort the image build (same outcome as before, minus the false "installed" message).

#######################################
# Install requirements.txt for each immediate subdirectory of a directory.
# Arguments: $1 - directory containing one folder per extension
# Outputs:   progress on stdout, failures on stderr
# Returns:   0 always (best-effort)
#######################################
install_extension_requirements() {
  for folder in "$1"/*; do
    # Skips plain files, and the literal pattern when the directory is empty or missing.
    [ -d "$folder" ] || continue
    if [ -f "$folder/requirements.txt" ]; then
      echo "Installing requirements in $folder..."
      # Run from inside the extension folder (relative paths in requirements.txt keep
      # working) in a subshell, so the caller's working directory is never changed.
      if (cd "$folder" && pip3 install -r requirements.txt); then
        echo "Requirements installed in $folder"
      else
        echo "Failed to install requirements in $folder; continuing with the remaining extensions" >&2
      fi
    else
      echo "Skipping $folder: requirements.txt not found"
    fi
  done
  return 0
}

# EXTENSIONS_DIR may override the default location (e.g. for local testing).
install_extension_requirements "${EXTENSIONS_DIR:-/app/extensions}"
================================================
FILE: tgwui/scripts/docker-entrypoint.sh
================================================
#!/bin/bash
# Container entrypoint for text-generation-webui:
#   1. seeds empty config directories in /app from the pristine copy in /src,
#   2. prints the image variant, source freshness and build date,
#   3. replaces itself with the command given as arguments (the image CMD) followed by
#      every word of $EXTRA_LAUNCH_ARGS.

# Handle termination signals that arrive while this script is still doing setup.
# (After `exec` below the server replaces this shell and receives signals directly.)
function ctrl_c {
    echo -e "\nKilling container!"
    # Add your cleanup actions here
    exit 0
}
# Register the keyboard interrupt handler
trap ctrl_c SIGTERM SIGINT SIGQUIT SIGHUP

# Generate default configs if empty (e.g. a freshly mounted, empty volume)
CONFIG_DIRECTORIES=("loras" "models" "presets" "prompts" "training/datasets" "training/formats")
for config_dir in "${CONFIG_DIRECTORIES[@]}"; do
    # Make sure the target exists, otherwise the copy below fails on a missing directory.
    mkdir -p /app/"$config_dir"
    if [ -z "$(ls /app/"$config_dir")" ]; then
        echo "*** Initialising config for: '$config_dir' ***"
        cp -ar /src/"$config_dir"/* /app/"$config_dir"/
        chown -R 1000:1000 /app/"$config_dir" # Not ideal... but convenient.
    fi
done

# Print variant
VARIANT=$(cat /variant.txt)
echo "=== Running text-generation-webui variant: '$VARIANT' ==="

# Print version freshness: how many commits the baked-in /src checkout is behind origin.
cur_dir=$(pwd)
src_dir="/src"
cd "$src_dir"
if git fetch origin >/dev/null 2>&1; then
    COMMITS_BEHIND=$(git rev-list HEAD..origin --count)
else
    # Typically no network access from the container, or /src is not a git checkout.
    COMMITS_BEHIND="UNKNOWN"
fi
echo "=== (This version is $COMMITS_BEHIND commits behind origin) ==="
cd "$cur_dir"

# Print build date
BUILD_DATE=$(cat /build_date.txt)
echo "=== Image build date: $BUILD_DATE ==="

# Assemble CMD and extra launch args. `eval` is used so shell quoting inside
# EXTRA_LAUNCH_ARGS is honoured (e.g. --foo "a b"); the variable is trusted container
# configuration and must never carry untrusted input.
eval "extra_launch_args=($EXTRA_LAUNCH_ARGS)"
# Quote both expansions: "$@" keeps each CMD word intact, and "${arr[@]}" passes *every*
# extra arg (the former unquoted `$extra_launch_args` expanded only the first element).
LAUNCHER=("$@" "${extra_launch_args[@]}")

# exec replaces this shell, so the server receives SIGTERM/SIGINT from `docker stop`
# directly. With a foreground child, bash would defer the trap until the child exits,
# leaving the server unsignalled until Docker's SIGKILL.
exec "${LAUNCHER[@]}"
================================================
FILE: ui.py
================================================
import os
import sys
import subprocess
from time import sleep
import shutil
from sys import platform
from superagi.lib.logger import logger
def check_command(command, message):
    """Exit the launcher if a required executable is not available.

    Args:
        command: Executable name (looked up on PATH) or path, checked with ``shutil.which``.
        message: Explanation logged when the executable is missing.

    Raises:
        SystemExit: with status 1 when ``command`` cannot be found.
    """
    if not shutil.which(command):
        # Fatal condition: log at error level so it is not hidden by an INFO-filtered logger.
        logger.error(message)
        sys.exit(1)
def run_npm_commands(shell=False):
    """Install the GUI's npm dependencies by running ``npm install`` inside ``gui/``.

    Args:
        shell: Forwarded to ``subprocess.run`` (the launcher passes True on Windows/Cygwin).

    Raises:
        SystemExit: with status 1 if ``npm install`` exits non-zero.
    """
    try:
        # cwd= instead of os.chdir: the process-wide working directory is never changed,
        # so it cannot be left pointing at gui/ if the command fails.
        subprocess.run(["npm", "install"], check=True, shell=shell, cwd="gui")
    except subprocess.CalledProcessError as exc:
        logger.error(f"Error during '{' '.join(exc.cmd)}'. Exiting.")
        sys.exit(1)
def run_server(shell=False):
    """Start the API server, the Celery worker and the GUI dev server as child processes.

    Args:
        shell: Forwarded to ``subprocess.Popen`` (the launcher passes True on Windows/Cygwin).

    Returns:
        Tuple ``(api_process, ui_process, celery_process)`` of ``subprocess.Popen`` handles.
    """
    api_process = subprocess.Popen(
        ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"], shell=shell
    )
    celery_process = subprocess.Popen(
        ["celery", "-A", "superagi.worker", "worker", "--loglevel=info"], shell=shell
    )
    # The GUI dev server runs from gui/; pass cwd= rather than os.chdir so the parent's
    # working directory is untouched even if this Popen raises.
    ui_process = subprocess.Popen(["npm", "run", "dev"], shell=shell, cwd="gui")
    return api_process, ui_process, celery_process
def cleanup(api_process, ui_process, celery_process, timeout=10):
    """Stop the child processes started by ``run_server`` and exit with status 1.

    Every process is first asked to stop (``terminate``). Any that is still running after
    ``timeout`` seconds is killed, so no server outlives the launcher.

    Args:
        api_process: ``subprocess.Popen`` handle, or None (skipped).
        ui_process: ``subprocess.Popen`` handle, or None (skipped).
        celery_process: ``subprocess.Popen`` handle, or None (skipped).
        timeout: Seconds to wait for each process after terminating it.

    Raises:
        SystemExit: always, with status 1.
    """
    logger.info("Shutting down processes...")
    processes = [p for p in (api_process, ui_process, celery_process) if p is not None]
    # Signal all first so they shut down in parallel, then reap each one.
    for process in processes:
        process.terminate()
    for process in processes:
        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
    logger.info("Processes terminated. Exiting.")
    sys.exit(1)
if __name__ == "__main__":
    check_command("node", "Node.js is not installed. Please install it and try again.")
    check_command("npm", "npm is not installed. Please install npm to proceed.")
    check_command("uvicorn", "uvicorn is not installed. Please install uvicorn to proceed.")

    # On Windows/Cygwin the child commands are launched through the shell.
    isWindows = platform in ("win32", "cygwin")
    run_npm_commands(shell=isWindows)

    try:
        api_process, ui_process, celery_process = run_server(isWindows)
    except Exception as exc:
        # run_server returned no handles, so there is nothing to pass to cleanup();
        # previously this path raised NameError on the undefined process variables.
        logger.error(f"Failed to start SuperAGI processes: {exc}. Exiting.")
        sys.exit(1)

    try:
        # Idle until interrupted; the child processes do the actual work.
        while True:
            sleep(30)
    except (KeyboardInterrupt, Exception):
        cleanup(api_process, ui_process, celery_process)
================================================
FILE: wait-for-it.sh
================================================
#!/usr/bin/env bash
# Use this script to test if a given TCP host/port are available
WAITFORIT_cmdname="${0##*/}"
# Write the arguments to stderr unless quiet mode (-q/--quiet) is on.
echoerr() {
    (( WAITFORIT_QUIET == 1 )) || echo "$@" 1>&2
}
usage()
{
cat << USAGE >&2
Usage:
$WAITFORIT_cmdname host:port [-s] [-t timeout] [-- command args]
-h HOST | --host=HOST Host or IP under test
-p PORT | --port=PORT TCP port under test
Alternatively, you specify the host and port as host:port
-s | --strict Only execute subcommand if the test succeeds
-q | --quiet Don't output any status messages
-t TIMEOUT | --timeout=TIMEOUT
Timeout in seconds, zero for no timeout
-- COMMAND ARGS Execute command with args after the test finishes
USAGE
exit 1
}
# Poll $WAITFORIT_HOST:$WAITFORIT_PORT once per second until a TCP connection succeeds.
# Uses `nc -z` when WAITFORIT_ISBUSY=1, otherwise bash's /dev/tcp. The loop has no
# deadline of its own; a bound comes from running it under `timeout` (the --child path).
# Sets WAITFORIT_start_ts / WAITFORIT_end_ts / WAITFORIT_result; returns the final probe status.
wait_for() {
    if (( WAITFORIT_TIMEOUT > 0 )); then
        echoerr "$WAITFORIT_cmdname: waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT"
    else
        echoerr "$WAITFORIT_cmdname: waiting for $WAITFORIT_HOST:$WAITFORIT_PORT without a timeout"
    fi

    WAITFORIT_start_ts=$(date +%s)
    while true; do
        # shellcheck disable=SC2086 — expansions kept unquoted exactly as before
        if (( WAITFORIT_ISBUSY == 1 )); then
            nc -z $WAITFORIT_HOST $WAITFORIT_PORT
        else
            (echo -n > /dev/tcp/$WAITFORIT_HOST/$WAITFORIT_PORT) >/dev/null 2>&1
        fi
        WAITFORIT_result=$?
        (( WAITFORIT_result == 0 )) && break
        sleep 1
    done

    WAITFORIT_end_ts=$(date +%s)
    echoerr "$WAITFORIT_cmdname: $WAITFORIT_HOST:$WAITFORIT_PORT is available after $((WAITFORIT_end_ts - WAITFORIT_start_ts)) seconds"
    return $WAITFORIT_result
}
# Bound the probe with `timeout`: re-invoke this script ($0) in --child mode under
# `timeout $WAITFORIT_TIMEOUT`, wait for it, and report when it did not succeed.
# Sets WAITFORIT_PID and WAITFORIT_RESULT; returns the child's exit status.
wait_for_wrapper()
{
    # In order to support SIGINT during timeout: http://unix.stackexchange.com/a/57692
    # The child runs in the background and is waited on, so the INT trap below can fire
    # while the probe is in progress.
    # WAITFORIT_BUSYTIMEFLAG is deliberately unquoted: when empty it must vanish entirely.
    if [[ $WAITFORIT_QUIET -eq 1 ]]; then
        timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --quiet --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT &
    else
        timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT &
    fi
    WAITFORIT_PID=$!
    # Forward Ctrl-C to the child. NOTE(review): `-PID` addresses a process *group*; without
    # job control the background child is presumably not a group leader, so this kill may
    # fail — verify.
    trap "kill -INT -$WAITFORIT_PID" INT
    wait $WAITFORIT_PID
    WAITFORIT_RESULT=$?
    # Any non-zero status (including timeout's own 124) is reported as a timeout.
    if [[ $WAITFORIT_RESULT -ne 0 ]]; then
        echoerr "$WAITFORIT_cmdname: timeout occurred after waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT"
    fi
    return $WAITFORIT_RESULT
}
# process arguments: fills WAITFORIT_HOST/PORT/TIMEOUT, the QUIET/STRICT/CHILD flags,
# and WAITFORIT_CLI (everything after `--`, run once the wait finishes).
while [[ $# -gt 0 ]]
do
    case "$1" in
        *:* )
        # host:port shorthand. NOTE(review): the unquoted expansion splits on the space that
        # replaced ':'; a host containing whitespace or glob characters would be mangled.
        WAITFORIT_hostport=(${1//:/ })
        WAITFORIT_HOST=${WAITFORIT_hostport[0]}
        WAITFORIT_PORT=${WAITFORIT_hostport[1]}
        shift 1
        ;;
        --child)
        # Internal flag used when wait_for_wrapper re-invokes this script under `timeout`.
        WAITFORIT_CHILD=1
        shift 1
        ;;
        -q | --quiet)
        WAITFORIT_QUIET=1
        shift 1
        ;;
        -s | --strict)
        WAITFORIT_STRICT=1
        shift 1
        ;;
        -h)
        WAITFORIT_HOST="$2"
        # Missing value: stop parsing; the host/port check after this loop reports it.
        if [[ $WAITFORIT_HOST == "" ]]; then break; fi
        shift 2
        ;;
        --host=*)
        WAITFORIT_HOST="${1#*=}"
        shift 1
        ;;
        -p)
        WAITFORIT_PORT="$2"
        if [[ $WAITFORIT_PORT == "" ]]; then break; fi
        shift 2
        ;;
        --port=*)
        WAITFORIT_PORT="${1#*=}"
        shift 1
        ;;
        -t)
        WAITFORIT_TIMEOUT="$2"
        if [[ $WAITFORIT_TIMEOUT == "" ]]; then break; fi
        shift 2
        ;;
        --timeout=*)
        WAITFORIT_TIMEOUT="${1#*=}"
        shift 1
        ;;
        --)
        # Everything after `--` is the command to execute, kept as an array of words.
        shift
        WAITFORIT_CLI=("$@")
        break
        ;;
        --help)
        usage
        ;;
        *)
        echoerr "Unknown argument: $1"
        usage
        ;;
    esac
done
# Host and port are mandatory; anything else falls back to a default below.
if [[ -z "$WAITFORIT_HOST" || -z "$WAITFORIT_PORT" ]]; then
    echoerr "Error: you need to provide a host and port to test."
    usage
fi
# Defaults (applied when unset or empty): 15 s timeout, non-strict, parent mode, verbose.
: "${WAITFORIT_TIMEOUT:=15}"
: "${WAITFORIT_STRICT:=0}"
: "${WAITFORIT_CHILD:=0}"
: "${WAITFORIT_QUIET:=0}"
# Check to see if timeout is from busybox?
# If it is, wait_for probes with `nc -z` (WAITFORIT_ISBUSY=1) instead of bash's /dev/tcp,
# and older BusyBox builds expect the duration after a `-t` flag.
WAITFORIT_TIMEOUT_PATH=$(type -p timeout)
# Resolve symlinks (e.g. timeout -> busybox) so the name check below sees the real binary.
# Quoted operands, and skipped entirely when `timeout` is not on PATH instead of invoking
# realpath/readlink with no operand.
if [[ -n "$WAITFORIT_TIMEOUT_PATH" ]]; then
    WAITFORIT_TIMEOUT_PATH=$(realpath "$WAITFORIT_TIMEOUT_PATH" 2>/dev/null || readlink -f "$WAITFORIT_TIMEOUT_PATH")
fi
WAITFORIT_BUSYTIMEFLAG=""
if [[ $WAITFORIT_TIMEOUT_PATH =~ "busybox" ]]; then
    WAITFORIT_ISBUSY=1
    # Check if busybox timeout uses -t flag
    # (recent Alpine versions don't support -t anymore)
    if timeout &>/dev/stdout | grep -q -e '-t '; then
        WAITFORIT_BUSYTIMEFLAG="-t"
    fi
else
    WAITFORIT_ISBUSY=0
fi
if [[ $WAITFORIT_CHILD -gt 0 ]]; then
    # Child mode: spawned under `timeout` by wait_for_wrapper; only probe and report.
    wait_for
    WAITFORIT_RESULT=$?
    exit "$WAITFORIT_RESULT"
else
    if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then
        wait_for_wrapper
        WAITFORIT_RESULT=$?
    else
        wait_for
        WAITFORIT_RESULT=$?
    fi
fi

# Hand over to the command given after `--`, if any. Test the array length: the former
# `$WAITFORIT_CLI != ""` looked only at the first element (SC2128), so a command whose
# first word is empty was silently treated as "no command".
if (( ${#WAITFORIT_CLI[@]} > 0 )); then
    if [[ $WAITFORIT_RESULT -ne 0 && $WAITFORIT_STRICT -eq 1 ]]; then
        echoerr "$WAITFORIT_cmdname: strict mode, refusing to execute subprocess"
        exit "$WAITFORIT_RESULT"
    fi
    exec "${WAITFORIT_CLI[@]}"
else
    exit "$WAITFORIT_RESULT"
fi