Repository: cohere-ai/cohere-python Branch: main Commit: 756b1d8ec0e4 Files: 367 Total size: 1.9 MB Directory structure: gitextract_zznh6amy/ ├── .fern/ │ └── metadata.json ├── .fernignore ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── improvement_request.md │ └── workflows/ │ └── ci.yml ├── .gitignore ├── 4.0.0-5.0.0-migration-guide.md ├── LICENSE ├── README.md ├── mypy.ini ├── pyproject.toml ├── reference.md ├── requirements.txt ├── src/ │ └── cohere/ │ ├── __init__.py │ ├── _default_clients.py │ ├── aliases.py │ ├── audio/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── raw_client.py │ │ └── transcriptions/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── raw_client.py │ │ └── types/ │ │ ├── __init__.py │ │ └── audio_transcriptions_create_response.py │ ├── aws_client.py │ ├── base_client.py │ ├── batches/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── raw_client.py │ │ └── types/ │ │ ├── __init__.py │ │ ├── batch.py │ │ ├── batch_status.py │ │ ├── cancel_batch_response.py │ │ ├── create_batch_response.py │ │ ├── get_batch_response.py │ │ └── list_batches_response.py │ ├── bedrock_client.py │ ├── client.py │ ├── client_v2.py │ ├── config.py │ ├── connectors/ │ │ ├── __init__.py │ │ ├── client.py │ │ └── raw_client.py │ ├── core/ │ │ ├── __init__.py │ │ ├── api_error.py │ │ ├── client_wrapper.py │ │ ├── datetime_utils.py │ │ ├── file.py │ │ ├── force_multipart.py │ │ ├── http_client.py │ │ ├── http_response.py │ │ ├── http_sse/ │ │ │ ├── __init__.py │ │ │ ├── _api.py │ │ │ ├── _decoders.py │ │ │ ├── _exceptions.py │ │ │ └── _models.py │ │ ├── jsonable_encoder.py │ │ ├── logging.py │ │ ├── parse_error.py │ │ ├── pydantic_utilities.py │ │ ├── query_encoder.py │ │ ├── remove_none_from_dict.py │ │ ├── request_options.py │ │ ├── serialization.py │ │ └── unchecked_base_model.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── raw_client.py │ │ └── types/ │ │ ├── __init__.py │ │ ├── datasets_create_response.py │ │ ├── datasets_get_response.py │ │ 
├── datasets_get_usage_response.py │ │ └── datasets_list_response.py │ ├── embed_jobs/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── raw_client.py │ │ └── types/ │ │ ├── __init__.py │ │ └── create_embed_job_request_truncate.py │ ├── environment.py │ ├── errors/ │ │ ├── __init__.py │ │ ├── bad_request_error.py │ │ ├── client_closed_request_error.py │ │ ├── forbidden_error.py │ │ ├── gateway_timeout_error.py │ │ ├── internal_server_error.py │ │ ├── invalid_token_error.py │ │ ├── not_found_error.py │ │ ├── not_implemented_error.py │ │ ├── service_unavailable_error.py │ │ ├── too_many_requests_error.py │ │ ├── unauthorized_error.py │ │ └── unprocessable_entity_error.py │ ├── finetuning/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── finetuning/ │ │ │ ├── __init__.py │ │ │ └── types/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── base_type.py │ │ │ ├── create_finetuned_model_response.py │ │ │ ├── delete_finetuned_model_response.py │ │ │ ├── event.py │ │ │ ├── finetuned_model.py │ │ │ ├── get_finetuned_model_response.py │ │ │ ├── hyperparameters.py │ │ │ ├── list_events_response.py │ │ │ ├── list_finetuned_models_response.py │ │ │ ├── list_training_step_metrics_response.py │ │ │ ├── lora_target_modules.py │ │ │ ├── settings.py │ │ │ ├── status.py │ │ │ ├── strategy.py │ │ │ ├── training_step_metrics.py │ │ │ ├── update_finetuned_model_response.py │ │ │ └── wandb_config.py │ │ └── raw_client.py │ ├── manually_maintained/ │ │ ├── __init__.py │ │ ├── cache.py │ │ ├── cohere_aws/ │ │ │ ├── __init__.py │ │ │ ├── chat.py │ │ │ ├── classification.py │ │ │ ├── client.py │ │ │ ├── embeddings.py │ │ │ ├── error.py │ │ │ ├── generation.py │ │ │ ├── mode.py │ │ │ ├── rerank.py │ │ │ ├── response.py │ │ │ └── summary.py │ │ ├── lazy_aws_deps.py │ │ ├── lazy_oci_deps.py │ │ ├── streaming_embed.py │ │ └── tokenizers.py │ ├── models/ │ │ ├── __init__.py │ │ ├── client.py │ │ └── raw_client.py │ ├── oci_client.py │ ├── overrides.py │ ├── py.typed │ ├── raw_base_client.py │ ├── 
sagemaker_client.py │ ├── types/ │ │ ├── __init__.py │ │ ├── api_meta.py │ │ ├── api_meta_api_version.py │ │ ├── api_meta_billed_units.py │ │ ├── api_meta_tokens.py │ │ ├── assistant_message.py │ │ ├── assistant_message_response.py │ │ ├── assistant_message_response_content_item.py │ │ ├── assistant_message_v2content.py │ │ ├── assistant_message_v2content_one_item.py │ │ ├── auth_token_type.py │ │ ├── chat_citation.py │ │ ├── chat_citation_generation_event.py │ │ ├── chat_citation_type.py │ │ ├── chat_connector.py │ │ ├── chat_content_delta_event.py │ │ ├── chat_content_delta_event_delta.py │ │ ├── chat_content_delta_event_delta_message.py │ │ ├── chat_content_delta_event_delta_message_content.py │ │ ├── chat_content_end_event.py │ │ ├── chat_content_start_event.py │ │ ├── chat_content_start_event_delta.py │ │ ├── chat_content_start_event_delta_message.py │ │ ├── chat_content_start_event_delta_message_content.py │ │ ├── chat_content_start_event_delta_message_content_type.py │ │ ├── chat_data_metrics.py │ │ ├── chat_debug_event.py │ │ ├── chat_document.py │ │ ├── chat_document_source.py │ │ ├── chat_finish_reason.py │ │ ├── chat_message.py │ │ ├── chat_message_end_event.py │ │ ├── chat_message_end_event_delta.py │ │ ├── chat_message_start_event.py │ │ ├── chat_message_start_event_delta.py │ │ ├── chat_message_start_event_delta_message.py │ │ ├── chat_message_v2.py │ │ ├── chat_messages.py │ │ ├── chat_request_citation_quality.py │ │ ├── chat_request_prompt_truncation.py │ │ ├── chat_request_safety_mode.py │ │ ├── chat_search_queries_generation_event.py │ │ ├── chat_search_query.py │ │ ├── chat_search_result.py │ │ ├── chat_search_result_connector.py │ │ ├── chat_search_results_event.py │ │ ├── chat_stream_end_event.py │ │ ├── chat_stream_end_event_finish_reason.py │ │ ├── chat_stream_event.py │ │ ├── chat_stream_event_type.py │ │ ├── chat_stream_request_citation_quality.py │ │ ├── chat_stream_request_prompt_truncation.py │ │ ├── chat_stream_request_safety_mode.py │ 
│ ├── chat_stream_start_event.py │ │ ├── chat_text_content.py │ │ ├── chat_text_generation_event.py │ │ ├── chat_text_response_format.py │ │ ├── chat_text_response_format_v2.py │ │ ├── chat_thinking_content.py │ │ ├── chat_tool_call_delta_event.py │ │ ├── chat_tool_call_delta_event_delta.py │ │ ├── chat_tool_call_delta_event_delta_message.py │ │ ├── chat_tool_call_delta_event_delta_message_tool_calls.py │ │ ├── chat_tool_call_delta_event_delta_message_tool_calls_function.py │ │ ├── chat_tool_call_end_event.py │ │ ├── chat_tool_call_start_event.py │ │ ├── chat_tool_call_start_event_delta.py │ │ ├── chat_tool_call_start_event_delta_message.py │ │ ├── chat_tool_calls_chunk_event.py │ │ ├── chat_tool_calls_generation_event.py │ │ ├── chat_tool_message.py │ │ ├── chat_tool_plan_delta_event.py │ │ ├── chat_tool_plan_delta_event_delta.py │ │ ├── chat_tool_plan_delta_event_delta_message.py │ │ ├── chat_tool_source.py │ │ ├── check_api_key_response.py │ │ ├── citation.py │ │ ├── citation_end_event.py │ │ ├── citation_options.py │ │ ├── citation_options_mode.py │ │ ├── citation_start_event.py │ │ ├── citation_start_event_delta.py │ │ ├── citation_start_event_delta_message.py │ │ ├── citation_type.py │ │ ├── classify_data_metrics.py │ │ ├── classify_example.py │ │ ├── classify_request_truncate.py │ │ ├── classify_response.py │ │ ├── classify_response_classifications_item.py │ │ ├── classify_response_classifications_item_classification_type.py │ │ ├── classify_response_classifications_item_labels_value.py │ │ ├── compatible_endpoint.py │ │ ├── connector.py │ │ ├── connector_auth_status.py │ │ ├── connector_o_auth.py │ │ ├── content.py │ │ ├── create_connector_o_auth.py │ │ ├── create_connector_response.py │ │ ├── create_connector_service_auth.py │ │ ├── create_embed_job_response.py │ │ ├── dataset.py │ │ ├── dataset_part.py │ │ ├── dataset_type.py │ │ ├── dataset_validation_status.py │ │ ├── delete_connector_response.py │ │ ├── detokenize_response.py │ │ ├── document.py │ │ 
├── document_content.py │ │ ├── embed_by_type_response.py │ │ ├── embed_by_type_response_embeddings.py │ │ ├── embed_by_type_response_response_type.py │ │ ├── embed_content.py │ │ ├── embed_floats_response.py │ │ ├── embed_image.py │ │ ├── embed_image_url.py │ │ ├── embed_input.py │ │ ├── embed_input_type.py │ │ ├── embed_job.py │ │ ├── embed_job_status.py │ │ ├── embed_job_truncate.py │ │ ├── embed_request_truncate.py │ │ ├── embed_response.py │ │ ├── embed_text.py │ │ ├── embedding_type.py │ │ ├── finetune_dataset_metrics.py │ │ ├── finish_reason.py │ │ ├── generate_request_return_likelihoods.py │ │ ├── generate_request_truncate.py │ │ ├── generate_stream_end.py │ │ ├── generate_stream_end_response.py │ │ ├── generate_stream_error.py │ │ ├── generate_stream_event.py │ │ ├── generate_stream_request_return_likelihoods.py │ │ ├── generate_stream_request_truncate.py │ │ ├── generate_stream_text.py │ │ ├── generate_streamed_response.py │ │ ├── generation.py │ │ ├── get_connector_response.py │ │ ├── get_model_response.py │ │ ├── get_model_response_sampling_defaults.py │ │ ├── image.py │ │ ├── image_content.py │ │ ├── image_url.py │ │ ├── image_url_detail.py │ │ ├── json_response_format.py │ │ ├── json_response_format_v2.py │ │ ├── label_metric.py │ │ ├── list_connectors_response.py │ │ ├── list_embed_job_response.py │ │ ├── list_models_response.py │ │ ├── logprob_item.py │ │ ├── message.py │ │ ├── metrics.py │ │ ├── non_streamed_chat_response.py │ │ ├── o_auth_authorize_response.py │ │ ├── parse_info.py │ │ ├── rerank_document.py │ │ ├── rerank_request_documents_item.py │ │ ├── rerank_response.py │ │ ├── rerank_response_results_item.py │ │ ├── rerank_response_results_item_document.py │ │ ├── reranker_data_metrics.py │ │ ├── response_format.py │ │ ├── response_format_v2.py │ │ ├── single_generation.py │ │ ├── single_generation_in_stream.py │ │ ├── single_generation_token_likelihoods_item.py │ │ ├── source.py │ │ ├── streamed_chat_response.py │ │ ├── 
summarize_request_extractiveness.py │ │ ├── summarize_request_format.py │ │ ├── summarize_request_length.py │ │ ├── summarize_response.py │ │ ├── system_message_v2.py │ │ ├── system_message_v2content.py │ │ ├── system_message_v2content_one_item.py │ │ ├── thinking.py │ │ ├── thinking_type.py │ │ ├── tokenize_response.py │ │ ├── tool.py │ │ ├── tool_call.py │ │ ├── tool_call_delta.py │ │ ├── tool_call_v2.py │ │ ├── tool_call_v2function.py │ │ ├── tool_content.py │ │ ├── tool_message_v2.py │ │ ├── tool_message_v2content.py │ │ ├── tool_parameter_definitions_value.py │ │ ├── tool_result.py │ │ ├── tool_v2.py │ │ ├── tool_v2function.py │ │ ├── update_connector_response.py │ │ ├── usage.py │ │ ├── usage_billed_units.py │ │ ├── usage_tokens.py │ │ ├── user_message_v2.py │ │ └── user_message_v2content.py │ ├── utils.py │ ├── v2/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── raw_client.py │ │ └── types/ │ │ ├── __init__.py │ │ ├── v2chat_request_documents_item.py │ │ ├── v2chat_request_safety_mode.py │ │ ├── v2chat_request_tool_choice.py │ │ ├── v2chat_response.py │ │ ├── v2chat_stream_request_documents_item.py │ │ ├── v2chat_stream_request_safety_mode.py │ │ ├── v2chat_stream_request_tool_choice.py │ │ ├── v2chat_stream_response.py │ │ ├── v2embed_request_truncate.py │ │ ├── v2rerank_response.py │ │ └── v2rerank_response_results_item.py │ └── version.py └── tests/ ├── __init__.py ├── embed_job.jsonl ├── test_async_client.py ├── test_aws_client_unit.py ├── test_bedrock_client.py ├── test_client.py ├── test_client_init.py ├── test_client_v2.py ├── test_embed_streaming.py ├── test_embed_utils.py ├── test_oci_client.py ├── test_oci_mypy.py └── test_overrides.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .fern/metadata.json ================================================ { "cliVersion": "4.63.2", "generatorName": "fernapi/fern-python-sdk", 
"generatorVersion": "5.3.3", "generatorConfig": { "inline_request_params": false, "extras": { "oci": [ "oci" ] }, "extra_dependencies": { "fastavro": "^1.9.4", "requests": "^2.0.0", "types-requests": "^2.0.0", "tokenizers": ">=0.15,<1", "oci": { "version": "^2.165.0", "optional": true } }, "improved_imports": true, "pydantic_config": { "frozen": false, "union_naming": "v1", "require_optional_fields": false, "extra_fields": "allow", "use_str_enums": true, "skip_validation": true }, "timeout_in_seconds": 300, "client": { "class_name": "BaseCohere", "filename": "base_client.py", "exported_class_name": "Client", "exported_filename": "client.py" }, "additional_init_exports": [ { "from": "client", "imports": [ "Client", "AsyncClient" ] }, { "from": "bedrock_client", "imports": [ "BedrockClient", "BedrockClientV2" ] }, { "from": "sagemaker_client", "imports": [ "SagemakerClient", "SagemakerClientV2" ] }, { "from": "aws_client", "imports": [ "AwsClient" ] }, { "from": "oci_client", "imports": [ "OciClient", "OciClientV2" ] }, { "from": "client_v2", "imports": [ "AsyncClientV2", "ClientV2" ] }, { "from": "aliases", "imports": [ "StreamedChatResponseV2", "MessageStartStreamedChatResponseV2", "MessageEndStreamedChatResponseV2", "ContentStartStreamedChatResponseV2", "ContentDeltaStreamedChatResponseV2", "ContentEndStreamedChatResponseV2", "ToolCallStartStreamedChatResponseV2", "ToolCallDeltaStreamedChatResponseV2", "ToolCallEndStreamedChatResponseV2", "ChatResponse" ] } ] }, "originGitCommit": "8dfb5e03f14a05967c4cdeeb44429eb4c1dca198", "sdkVersion": "6.1.0" } ================================================ FILE: .fernignore ================================================ 4.0.0-5.0.0-migration-guide.md banner.png README.md src/cohere/client.py tests .github/workflows/ci.yml .github/ISSUE_TEMPLATE LICENSE .github/workflows/tests.yml src/cohere/utils.py src/cohere/overrides.py src/cohere/config.py src/cohere/manually_maintained src/cohere/manually_maintained/__init__.py 
src/cohere/bedrock_client.py src/cohere/aws_client.py src/cohere/sagemaker_client.py src/cohere/oci_client.py src/cohere/client_v2.py mypy.ini src/cohere/aliases.py ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report related to an SDK error about: Create a report to help us improve title: '' labels: '' --- **SDK Version (required)** Provide the version you are using. To get the version, run the following python snippet ```python import cohere print(cohere.__version__) # 5.6.1 ``` **Describe the bug** A clear and concise description of what the bug is. **Screenshots** If applicable, add screenshots to help explain your problem. ================================================ FILE: .github/ISSUE_TEMPLATE/improvement_request.md ================================================ --- name: Improvement request, or addition features about: Create a request to help us improve title: "" labels: "" --- **Describe the improvement** A clear and concise description of what the new improvement is. **Code snippet of expected outcome** If applicable, add a code snippet of how you'd like to see the feature implemented ================================================ FILE: .github/workflows/ci.yml ================================================ name: ci on: [push] jobs: compile: runs-on: ubuntu-latest steps: - name: Checkout repo uses: actions/checkout@v3 - name: Set up python uses: actions/setup-python@v4 with: python-version: "3.10" - name: Bootstrap poetry uses: snok/install-poetry@v1 with: version: 1.5.1 virtualenvs-in-project: false - name: Install dependencies run: poetry install - name: Compile run: poetry run mypy . 
test: runs-on: ubuntu-latest steps: - name: Checkout repo uses: actions/checkout@v3 - name: Set up python uses: actions/setup-python@v4 with: python-version: "3.10" - name: Bootstrap poetry uses: snok/install-poetry@v1 with: version: 1.5.1 virtualenvs-in-project: false - name: Install dependencies run: poetry install - name: Install aws deps run: poetry run pip install boto3 sagemaker botocore - name: Test run: poetry run pytest -rP -n auto . env: CO_API_KEY: ${{ secrets.COHERE_API_KEY }} - name: Install aiohttp extra run: poetry install --extras aiohttp - name: Test (aiohttp) run: poetry run pytest -rP -n auto -m aiohttp . || [ $? -eq 5 ] env: CO_API_KEY: ${{ secrets.COHERE_API_KEY }} publish: needs: [compile, test] if: github.event_name == 'push' && contains(github.ref, 'refs/tags/') runs-on: ubuntu-latest steps: - name: Checkout repo uses: actions/checkout@v3 - name: Set up python uses: actions/setup-python@v4 with: python-version: "3.10" - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 - name: Install dependencies run: poetry install - name: Publish to pypi run: | poetry config repositories.remote https://upload.pypi.org/legacy/ poetry --no-interaction -v publish --build --repository remote --username "$PYPI_USERNAME" --password "$PYPI_PASSWORD" env: PYPI_USERNAME: ${{ secrets.PYPI_USERNAME }} PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} ================================================ FILE: .gitignore ================================================ .mypy_cache/ .ruff_cache/ __pycache__/ dist/ poetry.toml ================================================ FILE: 4.0.0-5.0.0-migration-guide.md ================================================ ## `cohere==4` to `cohere==5` migration guide As we migrate from the handwritten, manually-maintained sdk to our auto-generated sdk, there are some breaking changes that must be accommodated during migration. 
These should mostly improve the developer experience but thank you for bearing with us as we make these changes. ### Installation To install the latest version of the cohere sdk, run `pip3 install --upgrade cohere`. ### Migrating usages #### Migrating function calls [This diff view](https://github.com/cohere-ai/cohere-python/compare/old-usage...new-usage) enumerates all usages of the old sdk and how they map to the new sdk. Some fields are no longer supported in the new sdk. #### Migrating streaming usage The `streaming: boolean` parameter is no longer supported in the new sdk. Instead, you can replace the `chat` function with `chat_stream` and `generate` function with `generate_stream`. These will automatically inject the `streaming` parameter into the request. The following is an example usage for `chat_stream`: ```python stream = co.chat_stream( message="Tell me a short story" ) for event in stream: if event.event_type == "text-generation": print(event.text, end='') ``` ### Migrating deprecated `num_workers` Client constructor parameter The Client constructor accepts an `httpx_client` which can be configured to limit the maximum number of connections. ```python limits = httpx.Limits(max_connections=10) cohere.Client(httpx_client=httpx.Client(limits=limits)) ``` ### Removed functionality (subject to change) The following lists name the functions that are not in the new SDK and what their ongoing support status is. 
#### No longer supported * check_api_key * loglikelihood * batch_generate * codebook * batch_tokenize * batch_detokenize * detect_language * generate_feedback * generate_preference_feedback * create_cluster_job * get_cluster_job * list_cluster_jobs * wait_for_cluster_job * create_custom_model * wait_for_custom_model * get_custom_model * get_custom_model_by_name * get_custom_model_metrics * list_custom_models ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 Cohere Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # Cohere Python SDK ![](banner.png) [![version badge](https://img.shields.io/pypi/v/cohere)](https://pypi.org/project/cohere/) ![license badge](https://img.shields.io/github/license/cohere-ai/cohere-python) [![fern shield](https://img.shields.io/badge/%F0%9F%8C%BF-SDK%20generated%20by%20Fern-brightgreen)](https://github.com/fern-api/fern) The Cohere Python SDK allows access to Cohere models across many different platforms: the cohere platform, AWS (Bedrock, Sagemaker), Azure, GCP and Oracle OCI. For a full list of support and snippets, please take a look at the [SDK support docs page](https://docs.cohere.com/docs/cohere-works-everywhere). ## Documentation Cohere documentation and API reference is available [here](https://docs.cohere.com/). ## Installation ``` pip install cohere ``` ## Usage ```Python import cohere co = cohere.ClientV2() response = co.chat( model="command-r-plus-08-2024", messages=[{"role": "user", "content": "hello world!"}], ) print(response) ``` > [!TIP] > You can set a system environment variable `CO_API_KEY` to avoid writing your api key within your code, e.g. add `export CO_API_KEY=theapikeyforyouraccount` > in your ~/.zshrc or ~/.bashrc, open a new terminal, then code calling `cohere.Client()` will read this key. ## Streaming The SDK supports streaming endpoints. To take advantage of this feature for chat, use `chat_stream`. ```Python import cohere co = cohere.ClientV2() response = co.chat_stream( model="command-r-plus-08-2024", messages=[{"role": "user", "content": "hello world!"}], ) for event in response: if event.type == "content-delta": print(event.delta.message.content.text, end="") ``` ## Oracle Cloud Infrastructure (OCI) The SDK supports Oracle Cloud Infrastructure (OCI) Generative AI service. 
First, install the OCI SDK: ``` pip install 'cohere[oci]' ``` Then use the `OciClient` or `OciClientV2`: ```Python import cohere # Using OCI config file authentication (default: ~/.oci/config) co = cohere.OciClient( oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) response = co.embed( model="embed-english-v3.0", texts=["Hello world"], input_type="search_document", ) print(response.embeddings) ``` ### OCI Authentication Methods **1. Config File (Default)** ```Python co = cohere.OciClient( oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", # Uses ~/.oci/config with DEFAULT profile ) ``` **2. Custom Profile** ```Python co = cohere.OciClient( oci_profile="MY_PROFILE", oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) ``` **3. Session-based Authentication (Security Token)** ```Python # Works with OCI CLI session tokens co = cohere.OciClient( oci_profile="MY_SESSION_PROFILE", # Profile with security_token_file oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) ``` **4. Direct Credentials** ```Python co = cohere.OciClient( oci_user_id="ocid1.user.oc1...", oci_fingerprint="xx:xx:xx:...", oci_tenancy_id="ocid1.tenancy.oc1...", oci_private_key_path="~/.oci/key.pem", oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) ``` **5. 
Instance Principal (for OCI Compute instances)** ```Python co = cohere.OciClient( auth_type="instance_principal", oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) ``` ### Supported OCI APIs The OCI client supports the following Cohere APIs: - **Embed**: Full support for all embedding models - **Chat**: Full support with both V1 (`OciClient`) and V2 (`OciClientV2`) APIs - Streaming available via `chat_stream()` - Supports Command-R and Command-A model families ### OCI Model Availability and Limitations **Available on OCI On-Demand Inference:** - ✅ **Embed models**: available on OCI Generative AI - ✅ **Chat models**: available via `OciClient` (V1) and `OciClientV2` (V2) **Not Available on OCI On-Demand Inference:** - ❌ **Generate API**: OCI TEXT_GENERATION models are base models that require fine-tuning before deployment - ❌ **Rerank API**: OCI TEXT_RERANK models are base models that require fine-tuning before deployment - ❌ **Multiple Embedding Types**: OCI on-demand models only support single embedding type per request (cannot request both `float` and `int8` simultaneously) **Note**: To use Generate or Rerank models on OCI, you need to: 1. Fine-tune the base model using OCI's fine-tuning service 2. Deploy the fine-tuned model to a dedicated endpoint 3. Update your code to use the deployed model endpoint For the latest model availability, see the [OCI Generative AI documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm). ## Contributing While we value open-source contributions to this SDK, the code is generated programmatically. Additions made directly would have to be moved over to our generation code, otherwise they would be overwritten upon the next generated release. Feel free to open a PR as a proof of concept, but know that we will not be able to merge it as-is. We suggest opening an issue first to discuss with us! On the other hand, contributions to the README are always very welcome! 
================================================ FILE: mypy.ini ================================================ [mypy] exclude = src/cohere/manually_maintained/cohere_aws ================================================ FILE: pyproject.toml ================================================ [project] name = "cohere" dynamic = ["version"] [tool.poetry] name = "cohere" version = "6.1.0" description = "" readme = "README.md" authors = [] keywords = [] license = "MIT" classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Programming Language :: Python :: 3.15", "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", "Operating System :: Microsoft :: Windows", "Topic :: Software Development :: Libraries :: Python Modules", "Typing :: Typed", "License :: OSI Approved :: MIT License" ] packages = [ { include = "cohere", from = "src"} ] [tool.poetry.urls] Repository = 'https://github.com/cohere-ai/cohere-python' [tool.poetry.dependencies] python = "^3.10" aiohttp = { version = ">=3.10.0,<4", optional = true} fastavro = "^1.9.4" httpx = ">=0.21.2" httpx-aiohttp = { version = "0.1.8", optional = true} oci = { version = "^2.165.0", optional = true} pydantic = ">= 1.9.2" pydantic-core = ">=2.18.2,<2.44.0" requests = "^2.0.0" tokenizers = ">=0.15,<1" types-requests = "^2.0.0" typing_extensions = ">= 4.0.0" [tool.poetry.group.dev.dependencies] mypy = "==1.13.0" pytest = "^8.2.0" pytest-asyncio = "^1.0.0" pytest-xdist = "^3.6.1" python-dateutil = "^2.9.0" types-python-dateutil = "^2.9.0.20240316" ruff = "==0.11.5" [tool.pytest.ini_options] testpaths = [ "tests" ] asyncio_mode = "auto" markers = [ "aiohttp: tests that require 
httpx_aiohttp to be installed", ] [tool.mypy] plugins = ["pydantic.mypy"] [tool.ruff] line-length = 120 [tool.ruff.lint] select = [ "E", # pycodestyle errors "F", # pyflakes "I", # isort ] ignore = [ "E402", # Module level import not at top of file "E501", # Line too long "E711", # Comparison to `None` should be `cond is not None` "E712", # Avoid equality comparisons to `True`; use `if ...:` checks "E721", # Use `is` and `is not` for type comparisons, or `isinstance()` for instance checks "E722", # Do not use bare `except` "E731", # Do not assign a `lambda` expression, use a `def` "F821", # Undefined name "F841" # Local variable ... is assigned to but never used ] [tool.ruff.lint.isort] section-order = ["future", "standard-library", "third-party", "first-party"] [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.poetry.extras] oci=["oci"] aiohttp=["aiohttp", "httpx-aiohttp"] ================================================ FILE: reference.md ================================================ # Reference
client.chat_stream(...) -> typing.Iterator[bytes]
#### 📝 Description
Generates a streamed text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.chat_stream( model="command-a-03-2025", message="hello!", ) ```
#### ⚙️ Parameters
**message:** `str` Text input for the model to respond to. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**stream:** `typing.Literal` Defaults to `false`. When `true`, the response will be a JSON stream of events. The final event will contain the complete response, and will have an `event_type` of `"stream-end"`. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**accepts:** `typing.Optional[typing.Literal]` — Pass `text/event-stream` to receive the streamed response as server-sent events. The default is `\n`-delimited events.
**model:** `typing.Optional[str]` The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Compatible Deployments: Cohere Platform, Private Deployments
**preamble:** `typing.Optional[str]` When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**chat_history:** `typing.Optional[typing.List[Message]]` A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**conversation_id:** `typing.Optional[str]` An alternative to `chat_history`. Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non-empty string. Compatible Deployments: Cohere Platform
**prompt_truncation:** `typing.Optional[ChatStreamRequestPromptTruncation]` Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. Compatible Deployments: - AUTO: Cohere Platform Only - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
**connectors:** `typing.Optional[typing.List[ChatConnector]]` Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one. When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG). Compatible Deployments: Cohere Platform
**search_queries_only:** `typing.Optional[bool]` Defaults to `false`. When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**documents:** `typing.Optional[typing.List[ChatDocument]]` A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. Example: ``` [ { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." }, ] ``` Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**citation_quality:** `typing.Optional[ChatStreamRequestCitationQuality]` Defaults to `"enabled"`. Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**temperature:** `typing.Optional[float]` Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**max_tokens:** `typing.Optional[int]` The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**max_input_tokens:** `typing.Optional[int]` The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer. Input will be truncated according to the `prompt_truncation` parameter. Compatible Deployments: Cohere Platform
**k:** `typing.Optional[int]` Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**p:** `typing.Optional[float]` Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**seed:** `typing.Optional[int]` If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**stop_sequences:** `typing.Optional[typing.List[str]]` A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**frequency_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**presence_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**raw_prompting:** `typing.Optional[bool]` When enabled, the user's prompt will be sent to the model without any pre-processing. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**tools:** `typing.Optional[typing.List[Tool]]` A list of available tools (functions) that the model may suggest invoking before producing a text response. When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**tool_results:** `typing.Optional[typing.List[ToolResult]]` A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. ``` tool_results = [ { "call": { "name": , "parameters": { : } }, "outputs": [{ : }] }, ... ] ``` **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**force_single_step:** `typing.Optional[bool]` — Forces the chat to be single step. Defaults to `false`.
**response_format:** `typing.Optional[ResponseFormat]`
**safety_mode:** `typing.Optional[ChatStreamRequestSafetyMode]` Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `NONE` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.chat(...) -> NonStreamedChatResponse
#### 📝 Description
Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.chat( model="command-a-03-2025", message="Tell me about LLMs", ) ```
#### ⚙️ Parameters
**message:** `str` Text input for the model to respond to. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**stream:** `typing.Literal` Defaults to `false`. When `true`, the response will be a JSON stream of events. The final event will contain the complete response, and will have an `event_type` of `"stream-end"`. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**accepts:** `typing.Optional[typing.Literal]` — Pass text/event-stream to receive the streamed response as server-sent events. The default is `\n` delimited events.
**model:** `typing.Optional[str]` The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Compatible Deployments: Cohere Platform, Private Deployments
**preamble:** `typing.Optional[str]` When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**chat_history:** `typing.Optional[typing.List[Message]]` A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**conversation_id:** `typing.Optional[str]` An alternative to `chat_history`. Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non-empty string. Compatible Deployments: Cohere Platform
**prompt_truncation:** `typing.Optional[ChatRequestPromptTruncation]` Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. Compatible Deployments: - AUTO: Cohere Platform Only - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
**connectors:** `typing.Optional[typing.List[ChatConnector]]` Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one. When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG). Compatible Deployments: Cohere Platform
**search_queries_only:** `typing.Optional[bool]` Defaults to `false`. When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**documents:** `typing.Optional[typing.List[ChatDocument]]` A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. Example: ``` [ { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." }, ] ``` Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**citation_quality:** `typing.Optional[ChatRequestCitationQuality]` Defaults to `"enabled"`. Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**temperature:** `typing.Optional[float]` Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**max_tokens:** `typing.Optional[int]` The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**max_input_tokens:** `typing.Optional[int]` The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer. Input will be truncated according to the `prompt_truncation` parameter. Compatible Deployments: Cohere Platform
**k:** `typing.Optional[int]` Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**p:** `typing.Optional[float]` Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**seed:** `typing.Optional[int]` If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**stop_sequences:** `typing.Optional[typing.List[str]]` A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**frequency_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**presence_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**raw_prompting:** `typing.Optional[bool]` When enabled, the user's prompt will be sent to the model without any pre-processing. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**tools:** `typing.Optional[typing.List[Tool]]` A list of available tools (functions) that the model may suggest invoking before producing a text response. When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**tool_results:** `typing.Optional[typing.List[ToolResult]]` A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. ``` tool_results = [ { "call": { "name": , "parameters": { : } }, "outputs": [{ : }] }, ... ] ``` **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**force_single_step:** `typing.Optional[bool]` — Forces the chat to be single step. Defaults to `false`.
**response_format:** `typing.Optional[ResponseFormat]`
**safety_mode:** `typing.Optional[ChatRequestSafetyMode]` Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `NONE` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.generate_stream(...) -> typing.Iterator[bytes]
#### 📝 Description
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API. Generates realistic text conditioned on a given input.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.generate_stream( prompt="Please explain to me how LLMs work", ) ```
#### ⚙️ Parameters
**prompt:** `str` The input text that serves as the starting point for generating the response. Note: The prompt will be pre-processed and modified before reaching the model.
**stream:** `typing.Literal` When `true`, the response will be a JSON stream of events. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated. The final event will contain the complete response, and will contain an `is_finished` field set to `true`. The event will also contain a `finish_reason`, which can be one of the following: - `COMPLETE` - the model sent back a finished reply - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens for its context length - `ERROR` - something went wrong when generating the reply - `ERROR_TOXIC` - the model generated a reply that was deemed toxic
**model:** `typing.Optional[str]` The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
**num_generations:** `typing.Optional[int]` — The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.
**max_tokens:** `typing.Optional[int]` The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details. Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.
**truncate:** `typing.Optional[GenerateStreamRequestTruncate]` One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
**temperature:** `typing.Optional[float]` A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details. Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
**seed:** `typing.Optional[int]` If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**preset:** `typing.Optional[str]` Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate). When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.
**end_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
**stop_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.
**k:** `typing.Optional[int]` Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`.
**p:** `typing.Optional[float]` Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
**frequency_penalty:** `typing.Optional[float]` Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
**presence_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
**return_likelihoods:** `typing.Optional[GenerateStreamRequestReturnLikelihoods]` One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. If `GENERATION` is selected, the token likelihoods will only be provided for generated text. WARNING: `ALL` is deprecated, and will be removed in a future release.
**raw_prompting:** `typing.Optional[bool]` — When enabled, the user's prompt will be sent to the model without any pre-processing.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.generate(...) -> Generation
#### 📝 Description
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. Generates realistic text conditioned on a given input.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.generate( prompt="Please explain to me how LLMs work", ) ```
#### ⚙️ Parameters
**prompt:** `str` The input text that serves as the starting point for generating the response. Note: The prompt will be pre-processed and modified before reaching the model.
**stream:** `typing.Literal` When `true`, the response will be a JSON stream of events. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated. The final event will contain the complete response, and will contain an `is_finished` field set to `true`. The event will also contain a `finish_reason`, which can be one of the following: - `COMPLETE` - the model sent back a finished reply - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens for its context length - `ERROR` - something went wrong when generating the reply - `ERROR_TOXIC` - the model generated a reply that was deemed toxic
**model:** `typing.Optional[str]` The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
**num_generations:** `typing.Optional[int]` — The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.
**max_tokens:** `typing.Optional[int]` The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details. Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.
**truncate:** `typing.Optional[GenerateRequestTruncate]` One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
**temperature:** `typing.Optional[float]` A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details. Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
**seed:** `typing.Optional[int]` If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
**preset:** `typing.Optional[str]` Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate). When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.
**end_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
**stop_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.
**k:** `typing.Optional[int]` Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`.
**p:** `typing.Optional[float]` Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
**frequency_penalty:** `typing.Optional[float]` Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
**presence_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
**return_likelihoods:** `typing.Optional[GenerateRequestReturnLikelihoods]` One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. If `GENERATION` is selected, the token likelihoods will only be provided for generated text. WARNING: `ALL` is deprecated, and will be removed in a future release.
**raw_prompting:** `typing.Optional[bool]` — When enabled, the user's prompt will be sent to the model without any pre-processing.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed(...) -> EmbedResponse
#### 📝 Description
This endpoint returns text and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents. Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page. If you want to learn more about how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.embed( texts=[ "hello", "goodbye" ], model="embed-v4.0", input_type="classification", ) ```
#### ⚙️ Parameters
**texts:** `typing.Optional[typing.List[str]]` — An array of strings for the model to embed. Maximum number of texts per call is `96`.
**images:** `typing.Optional[typing.List[str]]` An array of image data URIs for the model to embed. Maximum number of images per call is `1`. The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB. Images are only supported with Embed v3.0 and newer models.
**model:** `typing.Optional[str]` — ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
**input_type:** `typing.Optional[EmbedInputType]`
**embedding_types:** `typing.Optional[typing.List[EmbeddingType]]` Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models. * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
**truncate:** `typing.Optional[EmbedRequestTruncate]` One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.rerank(...) -> RerankResponse
#### 📝 Description
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.rerank( documents=[ { "text": "Carson City is the capital city of the American state of Nevada." }, { "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan." }, { "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages." }, { "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district." }, { "text": "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." } ], query="What is the capital of the United States?", top_n=3, model="rerank-v4.0-pro", ) ```
#### ⚙️ Parameters
**query:** `str` — The search query
**documents:** `typing.List[RerankRequestDocumentsItem]` A list of document objects or strings to rerank. If a document is provided the text field is required and all other fields will be preserved in the response. The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000. We recommend a maximum of 1,000 documents for optimal endpoint performance.
**model:** `typing.Optional[str]` — The identifier of the model to use, eg `rerank-v3.5`.
**top_n:** `typing.Optional[int]` — The number of most relevant documents or indices to return, defaults to the length of the documents
**rank_fields:** `typing.Optional[typing.List[str]]` — If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially. If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking.
**return_documents:** `typing.Optional[bool]` - If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request. - If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.
**max_chunks_per_doc:** `typing.Optional[int]` — The maximum number of chunks to produce internally from a document
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.classify(...) -> ClassifyResponse
#### 📝 Description
This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference. Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
#### 🔌 Usage
```python from cohere import Client, ClassifyExample from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.classify( examples=[ ClassifyExample( text="Dermatologists don\'t like her!", label="Spam", ), ClassifyExample( text="\'Hello, open to this?\'", label="Spam", ), ClassifyExample( text="I need help please wire me $1000 right now", label="Spam", ), ClassifyExample( text="Nice to know you ;)", label="Spam", ), ClassifyExample( text="Please help me?", label="Spam", ), ClassifyExample( text="Your parcel will be delivered today", label="Not spam", ), ClassifyExample( text="Review changes to our Terms and Conditions", label="Not spam", ), ClassifyExample( text="Weekly sync notes", label="Not spam", ), ClassifyExample( text="\'Re: Follow up from today\'s meeting\'", label="Not spam", ), ClassifyExample( text="Pre-read for tomorrow", label="Not spam", ) ], inputs=[ "Confirm your email address", "hey i need u to send some $" ], model="YOUR-FINE-TUNED-MODEL-ID", ) ```
#### ⚙️ Parameters
**inputs:** `typing.List[str]` A list of up to 96 texts to be classified. Each one must be a non-empty string. There is, however, no consistent, universal limit to the length a particular input can be. We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models). Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts.
**examples:** `typing.Optional[typing.List[ClassifyExample]]` An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`. Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
**model:** `typing.Optional[str]` — ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model
**preset:** `typing.Optional[str]` — The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground). If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters.
**truncate:** `typing.Optional[ClassifyRequestTruncate]` One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.summarize(...) -> SummarizeResponse
#### 📝 Description
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. Generates a summary in English for a given text.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.summarize( text="Ice cream is a sweetened frozen food typically eaten as a snack or dessert. It may be made from milk or cream and is flavoured with a sweetener, either sugar or an alternative, and a spice, such as cocoa or vanilla, or with fruit such as strawberries or peaches. It can also be made by whisking a flavored cream base and liquid nitrogen together. Food coloring is sometimes added, in addition to stabilizers. The mixture is cooled below the freezing point of water and stirred to incorporate air spaces and to prevent detectable ice crystals from forming. The result is a smooth, semi-solid foam that is solid at very low temperatures (below 2 °C or 35 °F). It becomes more malleable as its temperature increases.\n\nThe meaning of the name \"ice cream\" varies from one country to another. In some countries, such as the United States, \"ice cream\" applies only to a specific variety, and most governments regulate the commercial use of the various terms according to the relative quantities of the main ingredients, notably the amount of cream. Products that do not meet the criteria to be called ice cream are sometimes labelled \"frozen dairy dessert\" instead. In other countries, such as Italy and Argentina, one word is used fo\r all variants. Analogues made from dairy alternatives, such as goat\'s or sheep\'s milk, or milk substitutes (e.g., soy, cashew, coconut, almond milk or tofu), are available for those who are lactose intolerant, allergic to dairy protein or vegan.", ) ```
#### ⚙️ Parameters
**text:** `str` — The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English.
**length:** `typing.Optional[SummarizeRequestLength]` — One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text.
**format:** `typing.Optional[SummarizeRequestFormat]` — One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text.
**model:** `typing.Optional[str]` — The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better.
**extractiveness:** `typing.Optional[SummarizeRequestExtractiveness]` — One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text.
**temperature:** `typing.Optional[float]` — Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
**additional_command:** `typing.Optional[str]` — A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda"
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.tokenize(...) -> TokenizeResponse
#### 📝 Description
This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.tokenize( text="tokenize me! :D", model="command", ) ```
#### ⚙️ Parameters
**text:** `str` — The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters.
**model:** `str` — The input will be tokenized by the tokenizer that is used by this model.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.detokenize(...) -> DetokenizeResponse
#### 📝 Description
This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.detokenize( tokens=[ 10002, 2261, 2012, 8, 2792, 43 ], model="command", ) ```
#### ⚙️ Parameters
**tokens:** `typing.List[int]` — The list of tokens to be detokenized.
**model:** `str` — An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.check_api_key() -> CheckApiKeyResponse
#### 📝 Description
Checks that the api key in the Authorization header is valid and active
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.check_api_key() ```
#### ⚙️ Parameters
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## V2
client.v2.chat_stream(...) -> typing.Iterator[bytes]
#### 📝 Description
Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api). Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
#### 🔌 Usage
```python from cohere import Client, ChatMessageV2_User from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.v2.chat_stream( model="command-a-03-2025", messages=[ ChatMessageV2_User( content="Tell me about LLMs", ) ], ) ```
#### ⚙️ Parameters
**stream:** `typing.Literal` Defaults to `false`. When `true`, the response will be a SSE stream of events. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
**model:** `str` — The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).
**messages:** `ChatMessages`
**tools:** `typing.Optional[typing.List[ToolV2]]` A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).
**strict_tools:** `typing.Optional[bool]` When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools). **Note**: The first few requests with a new set of tools will take longer to process.
**documents:** `typing.Optional[typing.List[V2ChatStreamRequestDocumentsItem]]` — A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
**citation_options:** `typing.Optional[CitationOptions]`
**response_format:** `typing.Optional[ResponseFormatV2]`
**safety_mode:** `typing.Optional[V2ChatStreamRequestSafetyMode]` Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `OFF` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools` and `documents` parameters. **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
**max_tokens:** `typing.Optional[int]` The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models). **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`. **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.
**stop_sequences:** `typing.Optional[typing.List[str]]` — A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
**temperature:** `typing.Optional[float]` Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter.
**seed:** `typing.Optional[int]` If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.
**frequency_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
**presence_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
**k:** `typing.Optional[int]` Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled. Defaults to `0`, min value of `0`, max value of `500`.
**p:** `typing.Optional[float]` Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
**logprobs:** `typing.Optional[bool]` — Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
**tool_choice:** `typing.Optional[V2ChatStreamRequestToolChoice]` Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request. When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response. If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not. **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.
**thinking:** `typing.Optional[Thinking]`
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.v2.chat(...) -> V2ChatResponse
#### 📝 Description
Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api). Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
#### 🔌 Usage
```python from cohere import Client, ChatMessageV2_User from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.v2.chat( model="command-a-03-2025", messages=[ ChatMessageV2_User( content="Tell me about LLMs", ) ], ) ```
#### ⚙️ Parameters
**stream:** `typing.Literal` Defaults to `false`. When `true`, the response will be a SSE stream of events. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
**model:** `str` — The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).
**messages:** `ChatMessages`
**tools:** `typing.Optional[typing.List[ToolV2]]` A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).
**strict_tools:** `typing.Optional[bool]` When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools). **Note**: The first few requests with a new set of tools will take longer to process.
**documents:** `typing.Optional[typing.List[V2ChatRequestDocumentsItem]]` — A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
**citation_options:** `typing.Optional[CitationOptions]`
**response_format:** `typing.Optional[ResponseFormatV2]`
**safety_mode:** `typing.Optional[V2ChatRequestSafetyMode]` Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `OFF` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools` and `documents` parameters. **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
**max_tokens:** `typing.Optional[int]` The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models). **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`. **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.
**stop_sequences:** `typing.Optional[typing.List[str]]` — A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
**temperature:** `typing.Optional[float]` Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter.
**seed:** `typing.Optional[int]` If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.
**frequency_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
**presence_penalty:** `typing.Optional[float]` Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
**k:** `typing.Optional[int]` Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled. Defaults to `0`, min value of `0`, max value of `500`.
**p:** `typing.Optional[float]` Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
**logprobs:** `typing.Optional[bool]` — Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
**tool_choice:** `typing.Optional[V2ChatRequestToolChoice]` Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request. When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response. If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not. **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.
**thinking:** `typing.Optional[Thinking]`
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.v2.embed(...) -> EmbedByTypeResponse
#### 📝 Description
This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents. Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page. If you want to learn more about how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.v2.embed( texts=[ "hello", "goodbye" ], model="embed-v4.0", input_type="classification", embedding_types=[ "float" ], ) ```
#### ⚙️ Parameters
**model:** `str` — ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
**input_type:** `EmbedInputType`
**texts:** `typing.Optional[typing.List[str]]` — An array of strings for the model to embed. Maximum number of texts per call is `96`.
**images:** `typing.Optional[typing.List[str]]` An array of image data URIs for the model to embed. Maximum number of images per call is `1`. The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and have a maximum size of 5MB. Image embeddings are supported with Embed v3.0 and newer models.
**inputs:** `typing.Optional[typing.List[EmbedInput]]` — An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components.
**max_tokens:** `typing.Optional[int]` — The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.
**output_dimension:** `typing.Optional[int]` The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models. Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.
**embedding_types:** `typing.Optional[typing.List[EmbeddingType]]` Specifies the types of embeddings you want to get back. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models. * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models. * `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models.
**truncate:** `typing.Optional[V2EmbedRequestTruncate]` One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.v2.rerank(...) -> V2RerankResponse
#### 📝 Description
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.v2.rerank( documents=[ "Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." ], query="What is the capital of the United States?", top_n=3, model="rerank-v4.0-pro", ) ```
#### ⚙️ Parameters
**model:** `str` — The identifier of the model to use, e.g. `rerank-v3.5`.
**query:** `str` — The search query
**documents:** `typing.List[str]` A list of texts that will be compared to the `query`. For optimal performance we recommend against sending more than 1,000 documents in a single request. **Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`. **Note**: structured data should be formatted as YAML strings for best performance.
**top_n:** `typing.Optional[int]` — Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.
**max_tokens_per_doc:** `typing.Optional[int]` — Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Batches
client.batches.list(...) -> ListBatchesResponse
#### 📝 Description
List the batches for the current user
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.batches.list( page_size=1, page_token="page_token", order_by="order_by", ) ```
#### ⚙️ Parameters
**page_size:** `typing.Optional[int]` The maximum number of batches to return. The service may return fewer than this value. If unspecified, at most 50 batches will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000.
**page_token:** `typing.Optional[str]` A page token, received from a previous `ListBatches` call. Provide this to retrieve the subsequent page.
**order_by:** `typing.Optional[str]` Batches can be ordered by creation time or last updated time. Use `created_at` for creation time or `updated_at` for last updated time.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.batches.create(...) -> CreateBatchResponse
#### 📝 Description
Creates and executes a batch from an uploaded dataset of requests
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment from cohere.batches import Batch client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.batches.create( request=Batch( name="name", input_dataset_id="input_dataset_id", model="model", ), ) ```
#### ⚙️ Parameters
**request:** `Batch`
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.batches.retrieve(...) -> GetBatchResponse
#### 📝 Description
Retrieves a batch
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.batches.retrieve( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The batch ID.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.batches.cancel(...) -> CancelBatchResponse
#### 📝 Description
Cancels an in-progress batch
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.batches.cancel( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The batch ID.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## EmbedJobs
client.embed_jobs.list() -> ListEmbedJobResponse
#### 📝 Description
The list embed job endpoint allows users to view all embed jobs history for that specific user.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.embed_jobs.list() ```
#### ⚙️ Parameters
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed_jobs.create(...) -> CreateEmbedJobResponse
#### 📝 Description
This API launches an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`. The result of a completed embed job is a new Dataset of type `embed-output`, which contains the original text entries and the corresponding embeddings.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.embed_jobs.create( model="model", dataset_id="dataset_id", input_type="search_document", ) ```
#### ⚙️ Parameters
**model:** `str` ID of the embedding model. Available models and corresponding embedding dimensions: - `embed-english-v3.0` : 1024 - `embed-multilingual-v3.0` : 1024 - `embed-english-light-v3.0` : 384 - `embed-multilingual-light-v3.0` : 384
**dataset_id:** `str` — ID of a [Dataset](https://docs.cohere.com/docs/datasets). The Dataset must be of type `embed-input` and must have a validation status `Validated`
**input_type:** `EmbedInputType`
**name:** `typing.Optional[str]` — The name of the embed job.
**embedding_types:** `typing.Optional[typing.List[EmbeddingType]]` Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Valid for all models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for v3 and newer model versions. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for v3 and newer model versions. * `"binary"`: Use this when you want to get back signed binary embeddings. Valid for v3 and newer model versions. * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for v3 and newer model versions.
**truncate:** `typing.Optional[CreateEmbedJobRequestTruncate]` One of `START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed_jobs.get(...) -> EmbedJob
#### 📝 Description
This API retrieves the details about an embed job started by the same user.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.embed_jobs.get( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The ID of the embed job to retrieve.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed_jobs.cancel(...)
#### 📝 Description
This API allows users to cancel an active embed job. Once invoked, the embedding process will be terminated, and users will be charged for the embeddings processed up to the cancellation point. It's important to note that partial results will not be available to users after cancellation.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.embed_jobs.cancel( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The ID of the embed job to cancel.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Datasets
client.datasets.list(...) -> DatasetsListResponse
#### 📝 Description
List datasets that have been created.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment import datetime client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.datasets.list( dataset_type="datasetType", before=datetime.datetime.fromisoformat("2024-01-15T09:30:00+00:00"), after=datetime.datetime.fromisoformat("2024-01-15T09:30:00+00:00"), limit=1.1, offset=1.1, validation_status="unknown", ) ```
#### ⚙️ Parameters
**dataset_type:** `typing.Optional[str]` — optional filter by dataset type
**before:** `typing.Optional[datetime.datetime]` — optional filter before a date
**after:** `typing.Optional[datetime.datetime]` — optional filter after a date
**limit:** `typing.Optional[float]` — optional limit to number of results
**offset:** `typing.Optional[float]` — optional offset to start of results
**validation_status:** `typing.Optional[DatasetValidationStatus]` — optional filter by validation status
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.create(...) -> DatasetsCreateResponse
#### 📝 Description
Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.datasets.create( name="name", type="embed-input", keep_original_file=True, skip_malformed_input=True, text_separator="text_separator", csv_delimiter="csv_delimiter", data="example_data", eval_data="example_eval_data", ) ```
#### ⚙️ Parameters
**name:** `str` — The name of the uploaded dataset.
**type:** `DatasetType` — The dataset type, which is used to validate the data. The only valid type is `embed-input` used in conjunction with the Embed Jobs API.
**data:** `core.File` — The file to upload
**keep_original_file:** `typing.Optional[bool]` — Indicates if the original file should be stored.
**skip_malformed_input:** `typing.Optional[bool]` — Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field.
**keep_fields:** `typing.Optional[typing.Union[str, typing.Sequence[str]]]` — List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `keep_fields` are missing from the uploaded file, Dataset validation will fail.
**optional_fields:** `typing.Optional[typing.Union[str, typing.Sequence[str]]]` — List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, Datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `optional_fields` are missing from the uploaded file, Dataset validation will pass.
**text_separator:** `typing.Optional[str]` — Raw .txt uploads will be split into entries using the text_separator value.
**csv_delimiter:** `typing.Optional[str]` — The delimiter used for .csv uploads.
**eval_data:** `typing.Optional[core.File]` — An optional evaluation file to upload
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.get_usage() -> DatasetsGetUsageResponse
#### 📝 Description
View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.datasets.get_usage() ```
#### ⚙️ Parameters
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.get(...) -> DatasetsGetResponse
#### 📝 Description
Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.datasets.get( id="id", ) ```
#### ⚙️ Parameters
**id:** `str`
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.delete(...) -> typing.Dict[str, typing.Any]
#### 📝 Description
Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.datasets.delete( id="id", ) ```
#### ⚙️ Parameters
**id:** `str`
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Connectors
client.connectors.list(...) -> ListConnectorsResponse
#### 📝 Description
Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.connectors.list( limit=1.1, offset=1.1, ) ```
#### ⚙️ Parameters
**limit:** `typing.Optional[float]` — Maximum number of connectors to return [0, 100].
**offset:** `typing.Optional[float]` — Number of connectors to skip before returning results [0, inf].
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.create(...) -> CreateConnectorResponse
#### 📝 Description
Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.connectors.create( name="name", url="url", ) ```
#### ⚙️ Parameters
**name:** `str` — A human-readable name for the connector.
**url:** `str` — The URL of the connector that will be used to search for documents.
**description:** `typing.Optional[str]` — A description of the connector.
**excludes:** `typing.Optional[typing.List[str]]` — A list of fields to exclude from the prompt (fields remain in the document).
**oauth:** `typing.Optional[CreateConnectorOAuth]` — The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.
**active:** `typing.Optional[bool]` — Whether the connector is active or not.
**continue_on_failure:** `typing.Optional[bool]` — Whether a chat request should continue or not if the request to this connector fails.
**service_auth:** `typing.Optional[CreateConnectorServiceAuth]` — The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.get(...) -> GetConnectorResponse
#### 📝 Description
Retrieve a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.connectors.get( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The ID of the connector to retrieve.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.delete(...) -> DeleteConnectorResponse
#### 📝 Description
Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.connectors.delete( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The ID of the connector to delete.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.update(...) -> UpdateConnectorResponse
#### 📝 Description
Update a connector by ID. Omitted fields will not be updated. See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.connectors.update( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The ID of the connector to update.
**name:** `typing.Optional[str]` — A human-readable name for the connector.
**url:** `typing.Optional[str]` — The URL of the connector that will be used to search for documents.
**excludes:** `typing.Optional[typing.List[str]]` — A list of fields to exclude from the prompt (fields remain in the document).
**oauth:** `typing.Optional[CreateConnectorOAuth]` — The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.
**active:** `typing.Optional[bool]`
**continue_on_failure:** `typing.Optional[bool]`
**service_auth:** `typing.Optional[CreateConnectorServiceAuth]` — The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.o_auth_authorize(...) -> OAuthAuthorizeResponse
#### 📝 Description
Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.connectors.o_auth_authorize( id="id", after_token_redirect="after_token_redirect", ) ```
#### ⚙️ Parameters
**id:** `str` — The ID of the connector to authorize.
**after_token_redirect:** `typing.Optional[str]` — The URL to redirect to after the connector has been authorized.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Models
client.models.get(...) -> GetModelResponse
#### 📝 Description
Returns the details of a model, provided its name.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.models.get( model="command-a-03-2025", ) ```
#### ⚙️ Parameters
**model:** `str`
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.models.list(...) -> ListModelsResponse
#### 📝 Description
Returns a list of models available for use.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.models.list( page_size=1.1, page_token="page_token", endpoint="chat", default_only=True, ) ```
#### ⚙️ Parameters
**page_size:** `typing.Optional[float]` Maximum number of models to include in a page. Defaults to `20`, min value of `1`, max value of `1000`.
**page_token:** `typing.Optional[str]` — Page token provided in the `next_page_token` field of a previous response.
**endpoint:** `typing.Optional[CompatibleEndpoint]` — When provided, filters the list of models to only those that are compatible with the specified endpoint.
**default_only:** `typing.Optional[bool]` — When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## /finetuning
client.finetuning.list_finetuned_models(...) -> ListFinetunedModelsResponse
#### 📝 Description
Returns a list of fine-tuned models that the user has access to.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.finetuning.list_finetuned_models( page_size=1, page_token="page_token", order_by="order_by", ) ```
#### ⚙️ Parameters
**page_size:** `typing.Optional[int]` Maximum number of results to be returned by the server. If 0, defaults to 50.
**page_token:** `typing.Optional[str]` — Request a specific page of the list results.
**order_by:** `typing.Optional[str]` Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default)
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.create_finetuned_model(...) -> CreateFinetunedModelResponse
#### 📝 Description
Creates a new fine-tuned model. The model will be trained on the dataset specified in the request body. The training process may take some time, and the model will be available once the training is complete.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment from cohere.finetuning.finetuning import FinetunedModel, Settings, BaseModel client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.finetuning.create_finetuned_model( request=FinetunedModel( name="name", settings=Settings( base_model=BaseModel( base_type="BASE_TYPE_UNSPECIFIED", ), dataset_id="dataset_id", ), ), ) ```
#### ⚙️ Parameters
**request:** `FinetunedModel`
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.get_finetuned_model(...) -> GetFinetunedModelResponse
#### 📝 Description
Retrieve a fine-tuned model by its ID.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.finetuning.get_finetuned_model( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The fine-tuned model ID.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.delete_finetuned_model(...) -> DeleteFinetunedModelResponse
#### 📝 Description
Deletes a fine-tuned model. The model will be removed from the system and will no longer be available for use. This operation is irreversible.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.finetuning.delete_finetuned_model( id="id", ) ```
#### ⚙️ Parameters
**id:** `str` — The fine-tuned model ID.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.update_finetuned_model(...) -> UpdateFinetunedModelResponse
#### 📝 Description
Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment from cohere.finetuning.finetuning import Settings, BaseModel client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.finetuning.update_finetuned_model( id="id", name="name", settings=Settings( base_model=BaseModel( base_type="BASE_TYPE_UNSPECIFIED", ), dataset_id="dataset_id", ), ) ```
#### ⚙️ Parameters
**id:** `str` — FinetunedModel ID.
**name:** `str` — FinetunedModel name (e.g. `foobar`).
**settings:** `Settings` — FinetunedModel settings such as dataset, hyperparameters...
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.list_events(...) -> ListEventsResponse
#### 📝 Description
Returns a list of events that occurred during the life-cycle of the fine-tuned model. The events are ordered by creation time, with the most recent event first. The list can be paginated using `page_size` and `page_token` parameters.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.finetuning.list_events( finetuned_model_id="finetuned_model_id", page_size=1, page_token="page_token", order_by="order_by", ) ```
#### ⚙️ Parameters
**finetuned_model_id:** `str` — The parent fine-tuned model ID.
**page_size:** `typing.Optional[int]` Maximum number of results to be returned by the server. If 0, defaults to 50.
**page_token:** `typing.Optional[str]` — Request a specific page of the list results.
**order_by:** `typing.Optional[str]` Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default)
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.list_training_step_metrics(...) -> ListTrainingStepMetricsResponse
#### 📝 Description
Returns a list of metrics measured during the training of a fine-tuned model. The metrics are ordered by step number, with the most recent step first. The list can be paginated using `page_size` and `page_token` parameters.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.finetuning.list_training_step_metrics( finetuned_model_id="finetuned_model_id", page_size=1, page_token="page_token", ) ```
#### ⚙️ Parameters
**finetuned_model_id:** `str` — The parent fine-tuned model ID.
**page_size:** `typing.Optional[int]` Maximum number of results to be returned by the server. If 0, defaults to 50.
**page_token:** `typing.Optional[str]` — Request a specific page of the list results.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Audio Transcriptions
client.audio.transcriptions.create(...) -> AudioTranscriptionsCreateResponse
#### 📝 Description
Transcribe an audio file.
#### 🔌 Usage
```python from cohere import Client from cohere.environment import ClientEnvironment client = Client( token="", environment=ClientEnvironment.PRODUCTION, ) client.audio.transcriptions.create( file="example_file", model="model", language="language", ) ```
#### ⚙️ Parameters
**model:** `str` — ID of the model to use.
**language:** `str` — The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.
**file:** `core.File` — The audio file object to transcribe. Supported file extensions are flac, mp3, mpeg, mpga, ogg, and wav.
**temperature:** `typing.Optional[float]` — The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
================================================ FILE: requirements.txt ================================================ fastavro==1.9.4 httpx>=0.21.2 pydantic>= 1.9.2 pydantic-core>=2.18.2,<2.44.0 requests==2.0.0 tokenizers>=0.15,<1 types-requests==2.0.0 typing_extensions>= 4.0.0 ================================================ FILE: src/cohere/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. # isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .types import ( ApiMeta, ApiMetaApiVersion, ApiMetaBilledUnits, ApiMetaTokens, AssistantChatMessageV2, AssistantMessage, AssistantMessageResponse, AssistantMessageResponseContentItem, AssistantMessageV2Content, AssistantMessageV2ContentOneItem, AuthTokenType, ChatCitation, ChatCitationGenerationEvent, ChatCitationType, ChatConnector, ChatContentDeltaEvent, ChatContentDeltaEventDelta, ChatContentDeltaEventDeltaMessage, ChatContentDeltaEventDeltaMessageContent, ChatContentEndEvent, ChatContentStartEvent, ChatContentStartEventDelta, ChatContentStartEventDeltaMessage, ChatContentStartEventDeltaMessageContent, ChatContentStartEventDeltaMessageContentType, ChatDataMetrics, ChatDebugEvent, ChatDocument, ChatDocumentSource, ChatFinishReason, ChatMessage, ChatMessageEndEvent, ChatMessageEndEventDelta, ChatMessageStartEvent, ChatMessageStartEventDelta, ChatMessageStartEventDeltaMessage, ChatMessageV2, ChatMessages, ChatRequestCitationQuality, ChatRequestPromptTruncation, ChatRequestSafetyMode, ChatSearchQueriesGenerationEvent, ChatSearchQuery, ChatSearchResult, ChatSearchResultConnector, ChatSearchResultsEvent, ChatStreamEndEvent, ChatStreamEndEventFinishReason, ChatStreamEvent, ChatStreamEventType, ChatStreamRequestCitationQuality, ChatStreamRequestPromptTruncation, ChatStreamRequestSafetyMode, ChatStreamStartEvent, ChatTextContent, ChatTextGenerationEvent, ChatTextResponseFormat, ChatTextResponseFormatV2, 
ChatThinkingContent, ChatToolCallDeltaEvent, ChatToolCallDeltaEventDelta, ChatToolCallDeltaEventDeltaMessage, ChatToolCallDeltaEventDeltaMessageToolCalls, ChatToolCallDeltaEventDeltaMessageToolCallsFunction, ChatToolCallEndEvent, ChatToolCallStartEvent, ChatToolCallStartEventDelta, ChatToolCallStartEventDeltaMessage, ChatToolCallsChunkEvent, ChatToolCallsGenerationEvent, ChatToolMessage, ChatToolPlanDeltaEvent, ChatToolPlanDeltaEventDelta, ChatToolPlanDeltaEventDeltaMessage, ChatToolSource, ChatbotMessage, CheckApiKeyResponse, Citation, CitationEndEvent, CitationGenerationStreamedChatResponse, CitationOptions, CitationOptionsMode, CitationStartEvent, CitationStartEventDelta, CitationStartEventDeltaMessage, CitationType, ClassifyDataMetrics, ClassifyExample, ClassifyRequestTruncate, ClassifyResponse, ClassifyResponseClassificationsItem, ClassifyResponseClassificationsItemClassificationType, ClassifyResponseClassificationsItemLabelsValue, CompatibleEndpoint, Connector, ConnectorAuthStatus, ConnectorOAuth, Content, CreateConnectorOAuth, CreateConnectorResponse, CreateConnectorServiceAuth, CreateEmbedJobResponse, Dataset, DatasetPart, DatasetType, DatasetValidationStatus, DebugStreamedChatResponse, DeleteConnectorResponse, DetokenizeResponse, Document, DocumentContent, DocumentSource, DocumentToolContent, EmbedByTypeResponse, EmbedByTypeResponseEmbeddings, EmbedByTypeResponseResponseType, EmbedContent, EmbedFloatsResponse, EmbedImage, EmbedImageUrl, EmbedInput, EmbedInputType, EmbedJob, EmbedJobStatus, EmbedJobTruncate, EmbedRequestTruncate, EmbedResponse, EmbedText, EmbeddingType, EmbeddingsByTypeEmbedResponse, EmbeddingsFloatsEmbedResponse, FinetuneDatasetMetrics, FinishReason, GenerateRequestReturnLikelihoods, GenerateRequestTruncate, GenerateStreamEnd, GenerateStreamEndResponse, GenerateStreamError, GenerateStreamEvent, GenerateStreamRequestReturnLikelihoods, GenerateStreamRequestTruncate, GenerateStreamText, GenerateStreamedResponse, Generation, 
GetConnectorResponse, GetModelResponse, GetModelResponseSamplingDefaults, Image, ImageContent, ImageUrl, ImageUrlContent, ImageUrlDetail, ImageUrlEmbedContent, JsonObjectResponseFormat, JsonObjectResponseFormatV2, JsonResponseFormat, JsonResponseFormatV2, LabelMetric, ListConnectorsResponse, ListEmbedJobResponse, ListModelsResponse, LogprobItem, Message, Metrics, NonStreamedChatResponse, OAuthAuthorizeResponse, ParseInfo, RerankDocument, RerankRequestDocumentsItem, RerankResponse, RerankResponseResultsItem, RerankResponseResultsItemDocument, RerankerDataMetrics, ResponseFormat, ResponseFormatV2, SearchQueriesGenerationStreamedChatResponse, SearchResultsStreamedChatResponse, SingleGeneration, SingleGenerationInStream, SingleGenerationTokenLikelihoodsItem, Source, StreamEndGenerateStreamedResponse, StreamEndStreamedChatResponse, StreamErrorGenerateStreamedResponse, StreamStartStreamedChatResponse, StreamedChatResponse, SummarizeRequestExtractiveness, SummarizeRequestFormat, SummarizeRequestLength, SummarizeResponse, SystemChatMessageV2, SystemMessage, SystemMessageV2, SystemMessageV2Content, SystemMessageV2ContentOneItem, TextAssistantMessageResponseContentItem, TextAssistantMessageV2ContentOneItem, TextContent, TextEmbedContent, TextGenerationGenerateStreamedResponse, TextGenerationStreamedChatResponse, TextResponseFormat, TextResponseFormatV2, TextSystemMessageV2ContentOneItem, TextToolContent, Thinking, ThinkingAssistantMessageResponseContentItem, ThinkingAssistantMessageV2ContentOneItem, ThinkingType, TokenizeResponse, Tool, ToolCall, ToolCallDelta, ToolCallV2, ToolCallV2Function, ToolCallsChunkStreamedChatResponse, ToolCallsGenerationStreamedChatResponse, ToolChatMessageV2, ToolContent, ToolMessage, ToolMessageV2, ToolMessageV2Content, ToolParameterDefinitionsValue, ToolResult, ToolSource, ToolV2, ToolV2Function, UpdateConnectorResponse, Usage, UsageBilledUnits, UsageTokens, UserChatMessageV2, UserMessage, UserMessageV2, UserMessageV2Content, ) from .errors 
import ( BadRequestError, ClientClosedRequestError, ForbiddenError, GatewayTimeoutError, InternalServerError, InvalidTokenError, NotFoundError, NotImplementedError, ServiceUnavailableError, TooManyRequestsError, UnauthorizedError, UnprocessableEntityError, ) from . import audio, batches, connectors, datasets, embed_jobs, finetuning, models, v2 from ._default_clients import DefaultAioHttpClient, DefaultAsyncHttpxClient from .aliases import ( ChatResponse, ContentDeltaStreamedChatResponseV2, ContentEndStreamedChatResponseV2, ContentStartStreamedChatResponseV2, MessageEndStreamedChatResponseV2, MessageStartStreamedChatResponseV2, StreamedChatResponseV2, ToolCallDeltaStreamedChatResponseV2, ToolCallEndStreamedChatResponseV2, ToolCallStartStreamedChatResponseV2, ) from .aws_client import AwsClient from .batches import ( Batch, BatchStatus, CancelBatchResponse, CreateBatchResponse, GetBatchResponse, ListBatchesResponse, ) from .bedrock_client import BedrockClient, BedrockClientV2 from .client import AsyncClient, Client from .client_v2 import AsyncClientV2, ClientV2 from .datasets import DatasetsCreateResponse, DatasetsGetResponse, DatasetsGetUsageResponse, DatasetsListResponse from .embed_jobs import CreateEmbedJobRequestTruncate from .environment import ClientEnvironment from .oci_client import OciClient, OciClientV2 from .sagemaker_client import SagemakerClient, SagemakerClientV2 from .v2 import ( CitationEndV2ChatStreamResponse, CitationStartV2ChatStreamResponse, ContentDeltaV2ChatStreamResponse, ContentEndV2ChatStreamResponse, ContentStartV2ChatStreamResponse, DebugV2ChatStreamResponse, MessageEndV2ChatStreamResponse, MessageStartV2ChatStreamResponse, ToolCallDeltaV2ChatStreamResponse, ToolCallEndV2ChatStreamResponse, ToolCallStartV2ChatStreamResponse, ToolPlanDeltaV2ChatStreamResponse, V2ChatRequestDocumentsItem, V2ChatRequestSafetyMode, V2ChatRequestToolChoice, V2ChatResponse, V2ChatStreamRequestDocumentsItem, V2ChatStreamRequestSafetyMode, 
V2ChatStreamRequestToolChoice, V2ChatStreamResponse, V2EmbedRequestTruncate, V2RerankResponse, V2RerankResponseResultsItem, ) from .version import __version__ _dynamic_imports: typing.Dict[str, str] = { "ApiMeta": ".types", "ApiMetaApiVersion": ".types", "ApiMetaBilledUnits": ".types", "ApiMetaTokens": ".types", "AssistantChatMessageV2": ".types", "AssistantMessage": ".types", "AssistantMessageResponse": ".types", "AssistantMessageResponseContentItem": ".types", "AssistantMessageV2Content": ".types", "AssistantMessageV2ContentOneItem": ".types", "AsyncClient": ".client", "AsyncClientV2": ".client_v2", "AuthTokenType": ".types", "AwsClient": ".aws_client", "BadRequestError": ".errors", "Batch": ".batches", "BatchStatus": ".batches", "BedrockClient": ".bedrock_client", "BedrockClientV2": ".bedrock_client", "CancelBatchResponse": ".batches", "ChatCitation": ".types", "ChatCitationGenerationEvent": ".types", "ChatCitationType": ".types", "ChatConnector": ".types", "ChatContentDeltaEvent": ".types", "ChatContentDeltaEventDelta": ".types", "ChatContentDeltaEventDeltaMessage": ".types", "ChatContentDeltaEventDeltaMessageContent": ".types", "ChatContentEndEvent": ".types", "ChatContentStartEvent": ".types", "ChatContentStartEventDelta": ".types", "ChatContentStartEventDeltaMessage": ".types", "ChatContentStartEventDeltaMessageContent": ".types", "ChatContentStartEventDeltaMessageContentType": ".types", "ChatDataMetrics": ".types", "ChatDebugEvent": ".types", "ChatDocument": ".types", "ChatDocumentSource": ".types", "ChatFinishReason": ".types", "ChatMessage": ".types", "ChatMessageEndEvent": ".types", "ChatMessageEndEventDelta": ".types", "ChatMessageStartEvent": ".types", "ChatMessageStartEventDelta": ".types", "ChatMessageStartEventDeltaMessage": ".types", "ChatMessageV2": ".types", "ChatMessages": ".types", "ChatRequestCitationQuality": ".types", "ChatRequestPromptTruncation": ".types", "ChatRequestSafetyMode": ".types", "ChatResponse": ".aliases", 
"ChatSearchQueriesGenerationEvent": ".types", "ChatSearchQuery": ".types", "ChatSearchResult": ".types", "ChatSearchResultConnector": ".types", "ChatSearchResultsEvent": ".types", "ChatStreamEndEvent": ".types", "ChatStreamEndEventFinishReason": ".types", "ChatStreamEvent": ".types", "ChatStreamEventType": ".types", "ChatStreamRequestCitationQuality": ".types", "ChatStreamRequestPromptTruncation": ".types", "ChatStreamRequestSafetyMode": ".types", "ChatStreamStartEvent": ".types", "ChatTextContent": ".types", "ChatTextGenerationEvent": ".types", "ChatTextResponseFormat": ".types", "ChatTextResponseFormatV2": ".types", "ChatThinkingContent": ".types", "ChatToolCallDeltaEvent": ".types", "ChatToolCallDeltaEventDelta": ".types", "ChatToolCallDeltaEventDeltaMessage": ".types", "ChatToolCallDeltaEventDeltaMessageToolCalls": ".types", "ChatToolCallDeltaEventDeltaMessageToolCallsFunction": ".types", "ChatToolCallEndEvent": ".types", "ChatToolCallStartEvent": ".types", "ChatToolCallStartEventDelta": ".types", "ChatToolCallStartEventDeltaMessage": ".types", "ChatToolCallsChunkEvent": ".types", "ChatToolCallsGenerationEvent": ".types", "ChatToolMessage": ".types", "ChatToolPlanDeltaEvent": ".types", "ChatToolPlanDeltaEventDelta": ".types", "ChatToolPlanDeltaEventDeltaMessage": ".types", "ChatToolSource": ".types", "ChatbotMessage": ".types", "CheckApiKeyResponse": ".types", "Citation": ".types", "CitationEndEvent": ".types", "CitationEndV2ChatStreamResponse": ".v2", "CitationGenerationStreamedChatResponse": ".types", "CitationOptions": ".types", "CitationOptionsMode": ".types", "CitationStartEvent": ".types", "CitationStartEventDelta": ".types", "CitationStartEventDeltaMessage": ".types", "CitationStartV2ChatStreamResponse": ".v2", "CitationType": ".types", "ClassifyDataMetrics": ".types", "ClassifyExample": ".types", "ClassifyRequestTruncate": ".types", "ClassifyResponse": ".types", "ClassifyResponseClassificationsItem": ".types", 
"ClassifyResponseClassificationsItemClassificationType": ".types", "ClassifyResponseClassificationsItemLabelsValue": ".types", "Client": ".client", "ClientClosedRequestError": ".errors", "ClientEnvironment": ".environment", "ClientV2": ".client_v2", "CompatibleEndpoint": ".types", "Connector": ".types", "ConnectorAuthStatus": ".types", "ConnectorOAuth": ".types", "Content": ".types", "ContentDeltaStreamedChatResponseV2": ".aliases", "ContentDeltaV2ChatStreamResponse": ".v2", "ContentEndStreamedChatResponseV2": ".aliases", "ContentEndV2ChatStreamResponse": ".v2", "ContentStartStreamedChatResponseV2": ".aliases", "ContentStartV2ChatStreamResponse": ".v2", "CreateBatchResponse": ".batches", "CreateConnectorOAuth": ".types", "CreateConnectorResponse": ".types", "CreateConnectorServiceAuth": ".types", "CreateEmbedJobRequestTruncate": ".embed_jobs", "CreateEmbedJobResponse": ".types", "Dataset": ".types", "DatasetPart": ".types", "DatasetType": ".types", "DatasetValidationStatus": ".types", "DatasetsCreateResponse": ".datasets", "DatasetsGetResponse": ".datasets", "DatasetsGetUsageResponse": ".datasets", "DatasetsListResponse": ".datasets", "DebugStreamedChatResponse": ".types", "DebugV2ChatStreamResponse": ".v2", "DefaultAioHttpClient": "._default_clients", "DefaultAsyncHttpxClient": "._default_clients", "DeleteConnectorResponse": ".types", "DetokenizeResponse": ".types", "Document": ".types", "DocumentContent": ".types", "DocumentSource": ".types", "DocumentToolContent": ".types", "EmbedByTypeResponse": ".types", "EmbedByTypeResponseEmbeddings": ".types", "EmbedByTypeResponseResponseType": ".types", "EmbedContent": ".types", "EmbedFloatsResponse": ".types", "EmbedImage": ".types", "EmbedImageUrl": ".types", "EmbedInput": ".types", "EmbedInputType": ".types", "EmbedJob": ".types", "EmbedJobStatus": ".types", "EmbedJobTruncate": ".types", "EmbedRequestTruncate": ".types", "EmbedResponse": ".types", "EmbedText": ".types", "EmbeddingType": ".types", 
"EmbeddingsByTypeEmbedResponse": ".types", "EmbeddingsFloatsEmbedResponse": ".types", "FinetuneDatasetMetrics": ".types", "FinishReason": ".types", "ForbiddenError": ".errors", "GatewayTimeoutError": ".errors", "GenerateRequestReturnLikelihoods": ".types", "GenerateRequestTruncate": ".types", "GenerateStreamEnd": ".types", "GenerateStreamEndResponse": ".types", "GenerateStreamError": ".types", "GenerateStreamEvent": ".types", "GenerateStreamRequestReturnLikelihoods": ".types", "GenerateStreamRequestTruncate": ".types", "GenerateStreamText": ".types", "GenerateStreamedResponse": ".types", "Generation": ".types", "GetBatchResponse": ".batches", "GetConnectorResponse": ".types", "GetModelResponse": ".types", "GetModelResponseSamplingDefaults": ".types", "Image": ".types", "ImageContent": ".types", "ImageUrl": ".types", "ImageUrlContent": ".types", "ImageUrlDetail": ".types", "ImageUrlEmbedContent": ".types", "InternalServerError": ".errors", "InvalidTokenError": ".errors", "JsonObjectResponseFormat": ".types", "JsonObjectResponseFormatV2": ".types", "JsonResponseFormat": ".types", "JsonResponseFormatV2": ".types", "LabelMetric": ".types", "ListBatchesResponse": ".batches", "ListConnectorsResponse": ".types", "ListEmbedJobResponse": ".types", "ListModelsResponse": ".types", "LogprobItem": ".types", "Message": ".types", "MessageEndStreamedChatResponseV2": ".aliases", "MessageEndV2ChatStreamResponse": ".v2", "MessageStartStreamedChatResponseV2": ".aliases", "MessageStartV2ChatStreamResponse": ".v2", "Metrics": ".types", "NonStreamedChatResponse": ".types", "NotFoundError": ".errors", "NotImplementedError": ".errors", "OAuthAuthorizeResponse": ".types", "OciClient": ".oci_client", "OciClientV2": ".oci_client", "ParseInfo": ".types", "RerankDocument": ".types", "RerankRequestDocumentsItem": ".types", "RerankResponse": ".types", "RerankResponseResultsItem": ".types", "RerankResponseResultsItemDocument": ".types", "RerankerDataMetrics": ".types", "ResponseFormat": ".types", 
"ResponseFormatV2": ".types", "SagemakerClient": ".sagemaker_client", "SagemakerClientV2": ".sagemaker_client", "SearchQueriesGenerationStreamedChatResponse": ".types", "SearchResultsStreamedChatResponse": ".types", "ServiceUnavailableError": ".errors", "SingleGeneration": ".types", "SingleGenerationInStream": ".types", "SingleGenerationTokenLikelihoodsItem": ".types", "Source": ".types", "StreamEndGenerateStreamedResponse": ".types", "StreamEndStreamedChatResponse": ".types", "StreamErrorGenerateStreamedResponse": ".types", "StreamStartStreamedChatResponse": ".types", "StreamedChatResponse": ".types", "StreamedChatResponseV2": ".aliases", "SummarizeRequestExtractiveness": ".types", "SummarizeRequestFormat": ".types", "SummarizeRequestLength": ".types", "SummarizeResponse": ".types", "SystemChatMessageV2": ".types", "SystemMessage": ".types", "SystemMessageV2": ".types", "SystemMessageV2Content": ".types", "SystemMessageV2ContentOneItem": ".types", "TextAssistantMessageResponseContentItem": ".types", "TextAssistantMessageV2ContentOneItem": ".types", "TextContent": ".types", "TextEmbedContent": ".types", "TextGenerationGenerateStreamedResponse": ".types", "TextGenerationStreamedChatResponse": ".types", "TextResponseFormat": ".types", "TextResponseFormatV2": ".types", "TextSystemMessageV2ContentOneItem": ".types", "TextToolContent": ".types", "Thinking": ".types", "ThinkingAssistantMessageResponseContentItem": ".types", "ThinkingAssistantMessageV2ContentOneItem": ".types", "ThinkingType": ".types", "TokenizeResponse": ".types", "TooManyRequestsError": ".errors", "Tool": ".types", "ToolCall": ".types", "ToolCallDelta": ".types", "ToolCallDeltaStreamedChatResponseV2": ".aliases", "ToolCallDeltaV2ChatStreamResponse": ".v2", "ToolCallEndStreamedChatResponseV2": ".aliases", "ToolCallEndV2ChatStreamResponse": ".v2", "ToolCallStartStreamedChatResponseV2": ".aliases", "ToolCallStartV2ChatStreamResponse": ".v2", "ToolCallV2": ".types", "ToolCallV2Function": ".types", 
"ToolCallsChunkStreamedChatResponse": ".types", "ToolCallsGenerationStreamedChatResponse": ".types", "ToolChatMessageV2": ".types", "ToolContent": ".types", "ToolMessage": ".types", "ToolMessageV2": ".types", "ToolMessageV2Content": ".types", "ToolParameterDefinitionsValue": ".types", "ToolPlanDeltaV2ChatStreamResponse": ".v2", "ToolResult": ".types", "ToolSource": ".types", "ToolV2": ".types", "ToolV2Function": ".types", "UnauthorizedError": ".errors", "UnprocessableEntityError": ".errors", "UpdateConnectorResponse": ".types", "Usage": ".types", "UsageBilledUnits": ".types", "UsageTokens": ".types", "UserChatMessageV2": ".types", "UserMessage": ".types", "UserMessageV2": ".types", "UserMessageV2Content": ".types", "V2ChatRequestDocumentsItem": ".v2", "V2ChatRequestSafetyMode": ".v2", "V2ChatRequestToolChoice": ".v2", "V2ChatResponse": ".v2", "V2ChatStreamRequestDocumentsItem": ".v2", "V2ChatStreamRequestSafetyMode": ".v2", "V2ChatStreamRequestToolChoice": ".v2", "V2ChatStreamResponse": ".v2", "V2EmbedRequestTruncate": ".v2", "V2RerankResponse": ".v2", "V2RerankResponseResultsItem": ".v2", "__version__": ".version", "audio": ".audio", "batches": ".batches", "connectors": ".connectors", "datasets": ".datasets", "embed_jobs": ".embed_jobs", "finetuning": ".finetuning", "models": ".models", "v2": ".v2", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ 
"ApiMeta", "ApiMetaApiVersion", "ApiMetaBilledUnits", "ApiMetaTokens", "AssistantChatMessageV2", "AssistantMessage", "AssistantMessageResponse", "AssistantMessageResponseContentItem", "AssistantMessageV2Content", "AssistantMessageV2ContentOneItem", "AsyncClient", "AsyncClientV2", "AuthTokenType", "AwsClient", "BadRequestError", "Batch", "BatchStatus", "BedrockClient", "BedrockClientV2", "CancelBatchResponse", "ChatCitation", "ChatCitationGenerationEvent", "ChatCitationType", "ChatConnector", "ChatContentDeltaEvent", "ChatContentDeltaEventDelta", "ChatContentDeltaEventDeltaMessage", "ChatContentDeltaEventDeltaMessageContent", "ChatContentEndEvent", "ChatContentStartEvent", "ChatContentStartEventDelta", "ChatContentStartEventDeltaMessage", "ChatContentStartEventDeltaMessageContent", "ChatContentStartEventDeltaMessageContentType", "ChatDataMetrics", "ChatDebugEvent", "ChatDocument", "ChatDocumentSource", "ChatFinishReason", "ChatMessage", "ChatMessageEndEvent", "ChatMessageEndEventDelta", "ChatMessageStartEvent", "ChatMessageStartEventDelta", "ChatMessageStartEventDeltaMessage", "ChatMessageV2", "ChatMessages", "ChatRequestCitationQuality", "ChatRequestPromptTruncation", "ChatRequestSafetyMode", "ChatResponse", "ChatSearchQueriesGenerationEvent", "ChatSearchQuery", "ChatSearchResult", "ChatSearchResultConnector", "ChatSearchResultsEvent", "ChatStreamEndEvent", "ChatStreamEndEventFinishReason", "ChatStreamEvent", "ChatStreamEventType", "ChatStreamRequestCitationQuality", "ChatStreamRequestPromptTruncation", "ChatStreamRequestSafetyMode", "ChatStreamStartEvent", "ChatTextContent", "ChatTextGenerationEvent", "ChatTextResponseFormat", "ChatTextResponseFormatV2", "ChatThinkingContent", "ChatToolCallDeltaEvent", "ChatToolCallDeltaEventDelta", "ChatToolCallDeltaEventDeltaMessage", "ChatToolCallDeltaEventDeltaMessageToolCalls", "ChatToolCallDeltaEventDeltaMessageToolCallsFunction", "ChatToolCallEndEvent", "ChatToolCallStartEvent", "ChatToolCallStartEventDelta", 
"ChatToolCallStartEventDeltaMessage", "ChatToolCallsChunkEvent", "ChatToolCallsGenerationEvent", "ChatToolMessage", "ChatToolPlanDeltaEvent", "ChatToolPlanDeltaEventDelta", "ChatToolPlanDeltaEventDeltaMessage", "ChatToolSource", "ChatbotMessage", "CheckApiKeyResponse", "Citation", "CitationEndEvent", "CitationEndV2ChatStreamResponse", "CitationGenerationStreamedChatResponse", "CitationOptions", "CitationOptionsMode", "CitationStartEvent", "CitationStartEventDelta", "CitationStartEventDeltaMessage", "CitationStartV2ChatStreamResponse", "CitationType", "ClassifyDataMetrics", "ClassifyExample", "ClassifyRequestTruncate", "ClassifyResponse", "ClassifyResponseClassificationsItem", "ClassifyResponseClassificationsItemClassificationType", "ClassifyResponseClassificationsItemLabelsValue", "Client", "ClientClosedRequestError", "ClientEnvironment", "ClientV2", "CompatibleEndpoint", "Connector", "ConnectorAuthStatus", "ConnectorOAuth", "Content", "ContentDeltaStreamedChatResponseV2", "ContentDeltaV2ChatStreamResponse", "ContentEndStreamedChatResponseV2", "ContentEndV2ChatStreamResponse", "ContentStartStreamedChatResponseV2", "ContentStartV2ChatStreamResponse", "CreateBatchResponse", "CreateConnectorOAuth", "CreateConnectorResponse", "CreateConnectorServiceAuth", "CreateEmbedJobRequestTruncate", "CreateEmbedJobResponse", "Dataset", "DatasetPart", "DatasetType", "DatasetValidationStatus", "DatasetsCreateResponse", "DatasetsGetResponse", "DatasetsGetUsageResponse", "DatasetsListResponse", "DebugStreamedChatResponse", "DebugV2ChatStreamResponse", "DefaultAioHttpClient", "DefaultAsyncHttpxClient", "DeleteConnectorResponse", "DetokenizeResponse", "Document", "DocumentContent", "DocumentSource", "DocumentToolContent", "EmbedByTypeResponse", "EmbedByTypeResponseEmbeddings", "EmbedByTypeResponseResponseType", "EmbedContent", "EmbedFloatsResponse", "EmbedImage", "EmbedImageUrl", "EmbedInput", "EmbedInputType", "EmbedJob", "EmbedJobStatus", "EmbedJobTruncate", "EmbedRequestTruncate", 
"EmbedResponse", "EmbedText", "EmbeddingType", "EmbeddingsByTypeEmbedResponse", "EmbeddingsFloatsEmbedResponse", "FinetuneDatasetMetrics", "FinishReason", "ForbiddenError", "GatewayTimeoutError", "GenerateRequestReturnLikelihoods", "GenerateRequestTruncate", "GenerateStreamEnd", "GenerateStreamEndResponse", "GenerateStreamError", "GenerateStreamEvent", "GenerateStreamRequestReturnLikelihoods", "GenerateStreamRequestTruncate", "GenerateStreamText", "GenerateStreamedResponse", "Generation", "GetBatchResponse", "GetConnectorResponse", "GetModelResponse", "GetModelResponseSamplingDefaults", "Image", "ImageContent", "ImageUrl", "ImageUrlContent", "ImageUrlDetail", "ImageUrlEmbedContent", "InternalServerError", "InvalidTokenError", "JsonObjectResponseFormat", "JsonObjectResponseFormatV2", "JsonResponseFormat", "JsonResponseFormatV2", "LabelMetric", "ListBatchesResponse", "ListConnectorsResponse", "ListEmbedJobResponse", "ListModelsResponse", "LogprobItem", "Message", "MessageEndStreamedChatResponseV2", "MessageEndV2ChatStreamResponse", "MessageStartStreamedChatResponseV2", "MessageStartV2ChatStreamResponse", "Metrics", "NonStreamedChatResponse", "NotFoundError", "NotImplementedError", "OAuthAuthorizeResponse", "OciClient", "OciClientV2", "ParseInfo", "RerankDocument", "RerankRequestDocumentsItem", "RerankResponse", "RerankResponseResultsItem", "RerankResponseResultsItemDocument", "RerankerDataMetrics", "ResponseFormat", "ResponseFormatV2", "SagemakerClient", "SagemakerClientV2", "SearchQueriesGenerationStreamedChatResponse", "SearchResultsStreamedChatResponse", "ServiceUnavailableError", "SingleGeneration", "SingleGenerationInStream", "SingleGenerationTokenLikelihoodsItem", "Source", "StreamEndGenerateStreamedResponse", "StreamEndStreamedChatResponse", "StreamErrorGenerateStreamedResponse", "StreamStartStreamedChatResponse", "StreamedChatResponse", "StreamedChatResponseV2", "SummarizeRequestExtractiveness", "SummarizeRequestFormat", "SummarizeRequestLength", 
"SummarizeResponse", "SystemChatMessageV2", "SystemMessage", "SystemMessageV2", "SystemMessageV2Content", "SystemMessageV2ContentOneItem", "TextAssistantMessageResponseContentItem", "TextAssistantMessageV2ContentOneItem", "TextContent", "TextEmbedContent", "TextGenerationGenerateStreamedResponse", "TextGenerationStreamedChatResponse", "TextResponseFormat", "TextResponseFormatV2", "TextSystemMessageV2ContentOneItem", "TextToolContent", "Thinking", "ThinkingAssistantMessageResponseContentItem", "ThinkingAssistantMessageV2ContentOneItem", "ThinkingType", "TokenizeResponse", "TooManyRequestsError", "Tool", "ToolCall", "ToolCallDelta", "ToolCallDeltaStreamedChatResponseV2", "ToolCallDeltaV2ChatStreamResponse", "ToolCallEndStreamedChatResponseV2", "ToolCallEndV2ChatStreamResponse", "ToolCallStartStreamedChatResponseV2", "ToolCallStartV2ChatStreamResponse", "ToolCallV2", "ToolCallV2Function", "ToolCallsChunkStreamedChatResponse", "ToolCallsGenerationStreamedChatResponse", "ToolChatMessageV2", "ToolContent", "ToolMessage", "ToolMessageV2", "ToolMessageV2Content", "ToolParameterDefinitionsValue", "ToolPlanDeltaV2ChatStreamResponse", "ToolResult", "ToolSource", "ToolV2", "ToolV2Function", "UnauthorizedError", "UnprocessableEntityError", "UpdateConnectorResponse", "Usage", "UsageBilledUnits", "UsageTokens", "UserChatMessageV2", "UserMessage", "UserMessageV2", "UserMessageV2Content", "V2ChatRequestDocumentsItem", "V2ChatRequestSafetyMode", "V2ChatRequestToolChoice", "V2ChatResponse", "V2ChatStreamRequestDocumentsItem", "V2ChatStreamRequestSafetyMode", "V2ChatStreamRequestToolChoice", "V2ChatStreamResponse", "V2EmbedRequestTruncate", "V2RerankResponse", "V2RerankResponseResultsItem", "__version__", "audio", "batches", "connectors", "datasets", "embed_jobs", "finetuning", "models", "v2", ] ================================================ FILE: src/cohere/_default_clients.py ================================================ # This file was auto-generated by Fern from our API 
# This file was auto-generated by Fern from our API Definition.
import typing

import httpx

# Default request timeout (seconds) applied when the caller does not supply one.
SDK_DEFAULT_TIMEOUT = 60


def _with_sdk_defaults(kwargs: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
    """Fill in SDK-wide httpx defaults without clobbering caller-supplied values."""
    kwargs.setdefault("timeout", SDK_DEFAULT_TIMEOUT)
    kwargs.setdefault("follow_redirects", True)
    return kwargs


try:
    import httpx_aiohttp  # type: ignore[import-not-found]
except ImportError:

    class DefaultAioHttpClient(httpx.AsyncClient):  # type: ignore
        """Placeholder that fails loudly when the optional aiohttp extra is missing."""

        def __init__(self, **kwargs: typing.Any) -> None:
            raise RuntimeError("To use the aiohttp client, install the aiohttp extra: pip install cohere[aiohttp]")

else:

    class DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient):  # type: ignore
        """aiohttp-backed async client preconfigured with the SDK defaults."""

        def __init__(self, **kwargs: typing.Any) -> None:
            super().__init__(**_with_sdk_defaults(kwargs))


class DefaultAsyncHttpxClient(httpx.AsyncClient):
    """Plain httpx async client preconfigured with the SDK defaults."""

    def __init__(self, **kwargs: typing.Any) -> None:
        super().__init__(**_with_sdk_defaults(kwargs))

# ===== FILE: src/cohere/aliases.py =====
# Backwards-compatibility aliases for the v2 chat types.
#
# Overrides are imported first so compatibility patches (e.g. ToolCallV2.id
# being optional) are applied before any of the aliased types are used.
from . import overrides  # noqa: F401
from .v2 import (
    ContentDeltaV2ChatStreamResponse,
    ContentEndV2ChatStreamResponse,
    ContentStartV2ChatStreamResponse,
    MessageEndV2ChatStreamResponse,
    MessageStartV2ChatStreamResponse,
    ToolCallDeltaV2ChatStreamResponse,
    ToolCallEndV2ChatStreamResponse,
    ToolCallStartV2ChatStreamResponse,
    V2ChatResponse,
    V2ChatStreamResponse,
)

# Old public names -> current v2 classes.
ChatResponse = V2ChatResponse
StreamedChatResponseV2 = V2ChatStreamResponse
MessageStartStreamedChatResponseV2 = MessageStartV2ChatStreamResponse
MessageEndStreamedChatResponseV2 = MessageEndV2ChatStreamResponse
ContentStartStreamedChatResponseV2 = ContentStartV2ChatStreamResponse
ContentDeltaStreamedChatResponseV2 = ContentDeltaV2ChatStreamResponse
ContentEndStreamedChatResponseV2 = ContentEndV2ChatStreamResponse
ToolCallStartStreamedChatResponseV2 = ToolCallStartV2ChatStreamResponse
ToolCallDeltaStreamedChatResponseV2 = ToolCallDeltaV2ChatStreamResponse
ToolCallEndStreamedChatResponseV2 = ToolCallEndV2ChatStreamResponse

# ===== FILE: src/cohere/audio/__init__.py =====
import transcriptions from .transcriptions import AudioTranscriptionsCreateResponse _dynamic_imports: typing.Dict[str, str] = { "AudioTranscriptionsCreateResponse": ".transcriptions", "transcriptions": ".transcriptions", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = ["AudioTranscriptionsCreateResponse", "transcriptions"] ================================================ FILE: src/cohere/audio/client.py ================================================ # This file was auto-generated by Fern from our API Definition. from __future__ import annotations import typing from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from .raw_client import AsyncRawAudioClient, RawAudioClient if typing.TYPE_CHECKING: from .transcriptions.client import AsyncTranscriptionsClient, TranscriptionsClient class AudioClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._raw_client = RawAudioClient(client_wrapper=client_wrapper) self._client_wrapper = client_wrapper self._transcriptions: typing.Optional[TranscriptionsClient] = None @property def with_raw_response(self) -> RawAudioClient: """ Retrieves a raw implementation of this client that returns raw responses. 
Returns ------- RawAudioClient """ return self._raw_client @property def transcriptions(self): if self._transcriptions is None: from .transcriptions.client import TranscriptionsClient # noqa: E402 self._transcriptions = TranscriptionsClient(client_wrapper=self._client_wrapper) return self._transcriptions class AsyncAudioClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._raw_client = AsyncRawAudioClient(client_wrapper=client_wrapper) self._client_wrapper = client_wrapper self._transcriptions: typing.Optional[AsyncTranscriptionsClient] = None @property def with_raw_response(self) -> AsyncRawAudioClient: """ Retrieves a raw implementation of this client that returns raw responses. Returns ------- AsyncRawAudioClient """ return self._raw_client @property def transcriptions(self): if self._transcriptions is None: from .transcriptions.client import AsyncTranscriptionsClient # noqa: E402 self._transcriptions = AsyncTranscriptionsClient(client_wrapper=self._client_wrapper) return self._transcriptions ================================================ FILE: src/cohere/audio/raw_client.py ================================================ # This file was auto-generated by Fern from our API Definition. from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper class RawAudioClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._client_wrapper = client_wrapper class AsyncRawAudioClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper ================================================ FILE: src/cohere/audio/transcriptions/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# isort: skip_file

import typing
from importlib import import_module

if typing.TYPE_CHECKING:
    from .types import AudioTranscriptionsCreateResponse

# Maps exported attribute name -> relative module that defines it; consumed by
# the module-level __getattr__ below to defer imports until first access.
_dynamic_imports: typing.Dict[str, str] = {"AudioTranscriptionsCreateResponse": ".types"}


def __getattr__(attr_name: str) -> typing.Any:
    """Lazily import ``attr_name`` from the module recorded in ``_dynamic_imports``."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        if module_name == f".{attr_name}":
            # Mapping points at a module with the same name as the attribute:
            # return the module object itself rather than an attribute of it.
            return module
        else:
            return getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e


def __dir__():
    # Surface the lazily-importable names for dir()/autocomplete.
    lazy_attrs = list(_dynamic_imports.keys())
    return sorted(lazy_attrs)


__all__ = ["AudioTranscriptionsCreateResponse"]


================================================
FILE: src/cohere/audio/transcriptions/client.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

from ... import core
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ...core.request_options import RequestOptions
from .raw_client import AsyncRawTranscriptionsClient, RawTranscriptionsClient
from .types.audio_transcriptions_create_response import AudioTranscriptionsCreateResponse

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)


class TranscriptionsClient:
    """High-level sync transcriptions client; delegates to the raw client and unwraps ``.data``."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawTranscriptionsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawTranscriptionsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawTranscriptionsClient
        """
        return self._raw_client

    def create(
        self,
        *,
        model: str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AudioTranscriptionsCreateResponse:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            See core.File for more documentation

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AudioTranscriptionsCreateResponse
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.audio.transcriptions.create(
            model="model",
            language="language",
        )
        """
        # Delegate to the raw client and unwrap the parsed payload.
        _response = self._raw_client.create(
            model=model, language=language, file=file, temperature=temperature, request_options=request_options
        )
        return _response.data


class AsyncTranscriptionsClient:
    """High-level async transcriptions client; delegates to the raw client and unwraps ``.data``."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawTranscriptionsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawTranscriptionsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawTranscriptionsClient
        """
        return self._raw_client

    async def create(
        self,
        *,
        model: str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AudioTranscriptionsCreateResponse:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            See core.File for more documentation

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AudioTranscriptionsCreateResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.audio.transcriptions.create(
                model="model",
                language="language",
            )


        asyncio.run(main())
        """
        # Delegate to the raw client and unwrap the parsed payload.
        _response = await self._raw_client.create(
            model=model, language=language, file=file, temperature=temperature, request_options=request_options
        )
        return _response.data


================================================
FILE: src/cohere/audio/transcriptions/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing
from json.decoder import JSONDecodeError

from ...
import core
from ...core.api_error import ApiError
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ...core.http_response import AsyncHttpResponse, HttpResponse
from ...core.parse_error import ParsingError
from ...core.request_options import RequestOptions
from ...core.unchecked_base_model import construct_type
from ...errors.bad_request_error import BadRequestError
from ...errors.client_closed_request_error import ClientClosedRequestError
from ...errors.forbidden_error import ForbiddenError
from ...errors.gateway_timeout_error import GatewayTimeoutError
from ...errors.internal_server_error import InternalServerError
from ...errors.invalid_token_error import InvalidTokenError
from ...errors.not_found_error import NotFoundError
from ...errors.not_implemented_error import NotImplementedError
from ...errors.service_unavailable_error import ServiceUnavailableError
from ...errors.too_many_requests_error import TooManyRequestsError
from ...errors.unauthorized_error import UnauthorizedError
from ...errors.unprocessable_entity_error import UnprocessableEntityError
from .types.audio_transcriptions_create_response import AudioTranscriptionsCreateResponse
from pydantic import ValidationError

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)


class RawTranscriptionsClient:
    """Sync raw client: returns ``HttpResponse`` wrappers and raises typed SDK errors."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    def create(
        self,
        *,
        model: str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[AudioTranscriptionsCreateResponse]:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            See core.File for more documentation

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[AudioTranscriptionsCreateResponse]
            A successful response.
        """
        # Multipart upload: scalar fields travel in `data`, the audio in `files`.
        # `omit=OMIT` lets the wrapper drop unset optional fields (temperature).
        _response = self._client_wrapper.httpx_client.request(
            "v2/audio/transcriptions",
            method="POST",
            data={
                "model": model,
                "language": language,
                "temperature": temperature,
            },
            files={
                "file": file,
            },
            request_options=request_options,
            omit=OMIT,
            force_multipart=True,
        )
        try:
            # Success path: lazily parse the JSON body into the response model.
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    AudioTranscriptionsCreateResponse,
                    construct_type(
                        type_=AudioTranscriptionsCreateResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            # Error path: map each known status code to a typed SDK exception;
            # bodies are parsed leniently as `Any` since error shapes vary.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            # Unknown status: fall through to the generic ApiError below.
            _response_json = _response.json()
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            # JSON parsed but did not match the expected model shape.
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)


class AsyncRawTranscriptionsClient:
    """Async raw client: returns ``AsyncHttpResponse`` wrappers and raises typed SDK errors."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._client_wrapper = client_wrapper

    async def create(
        self,
        *,
        model:
        str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[AudioTranscriptionsCreateResponse]:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            See core.File for more documentation

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[AudioTranscriptionsCreateResponse]
            A successful response.
        """
        # Multipart upload: scalar fields travel in `data`, the audio in `files`.
        # `omit=OMIT` lets the wrapper drop unset optional fields (temperature).
        _response = await self._client_wrapper.httpx_client.request(
            "v2/audio/transcriptions",
            method="POST",
            data={
                "model": model,
                "language": language,
                "temperature": temperature,
            },
            files={
                "file": file,
            },
            request_options=request_options,
            omit=OMIT,
            force_multipart=True,
        )
        try:
            # Success path: lazily parse the JSON body into the response model.
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    AudioTranscriptionsCreateResponse,
                    construct_type(
                        type_=AudioTranscriptionsCreateResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            # Error path: map each known status code to a typed SDK exception;
            # bodies are parsed leniently as `Any` since error shapes vary.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            # Unknown status: fall through to the generic ApiError below.
            _response_json = _response.json()
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            # JSON parsed but did not match the expected model shape.
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)


================================================
FILE: src/cohere/audio/transcriptions/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.

# isort: skip_file

import typing
from importlib import import_module

if typing.TYPE_CHECKING:
    from .audio_transcriptions_create_response import AudioTranscriptionsCreateResponse

# Maps exported attribute name -> relative module that defines it; consumed by
# the module-level __getattr__ below to defer imports until first access.
_dynamic_imports: typing.Dict[str, str] = {"AudioTranscriptionsCreateResponse": ".audio_transcriptions_create_response"}


def __getattr__(attr_name: str) -> typing.Any:
    """Lazily import ``attr_name`` from the module recorded in ``_dynamic_imports``."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        if module_name == f".{attr_name}":
            # Mapping points at a module with the same name as the attribute:
            # return the module object itself rather than an attribute of it.
            return module
        else:
            return getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e


def __dir__():
    # Surface the lazily-importable names for dir()/autocomplete.
    lazy_attrs = list(_dynamic_imports.keys())
    return sorted(lazy_attrs)


__all__ = ["AudioTranscriptionsCreateResponse"]


================================================
FILE: src/cohere/audio/transcriptions/types/audio_transcriptions_create_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel


class AudioTranscriptionsCreateResponse(UncheckedBaseModel):
    """Response payload for a transcription request."""

    text: str = pydantic.Field()
    """
    The transcribed text.
    """

    # Allow unknown fields from the API: configured via model_config on
    # pydantic v2 and a nested Config class on pydantic v1.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/aws_client.py
================================================
import base64
import json
import re
import typing

import httpx
from httpx import URL, SyncByteStream, ByteStream

from . import GenerateStreamedResponse, Generation, \
    NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
    ApiMetaBilledUnits
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
from .client_v2 import ClientV2


class AwsClient(Client):
    """v1 client whose httpx transport is rerouted to AWS Bedrock/SageMaker via event hooks."""

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        Client.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            # Auth happens through AWS SigV4 in the request hook, not an API key.
            api_key="n/a",
            httpx_client=httpx.Client(
                event_hooks=get_event_hooks(
                    service=service,
                    aws_access_key=aws_access_key,
                    aws_secret_key=aws_secret_key,
                    aws_session_token=aws_session_token,
                    aws_region=aws_region,
                ),
                timeout=timeout,
            ),
        )


class AwsClientV2(ClientV2):
    """v2 client whose httpx transport is rerouted to AWS Bedrock/SageMaker via event hooks."""

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token:
        typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        ClientV2.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            # Auth happens through AWS SigV4 in the request hook, not an API key.
            api_key="n/a",
            httpx_client=httpx.Client(
                event_hooks=get_event_hooks(
                    service=service,
                    aws_access_key=aws_access_key,
                    aws_secret_key=aws_secret_key,
                    aws_session_token=aws_session_token,
                    aws_region=aws_region,
                ),
                timeout=timeout,
            ),
        )


EventHook = typing.Callable[..., typing.Any]


def get_event_hooks(
    service: str,
    aws_access_key: typing.Optional[str] = None,
    aws_secret_key: typing.Optional[str] = None,
    aws_session_token: typing.Optional[str] = None,
    aws_region: typing.Optional[str] = None,
) -> typing.Dict[str, typing.List[EventHook]]:
    """Build the httpx event hooks: one request hook that signs and reroutes to AWS,
    one response hook that translates the AWS payload back into Cohere shapes."""
    return {
        "request": [
            map_request_to_bedrock(
                service=service,
                aws_access_key=aws_access_key,
                aws_secret_key=aws_secret_key,
                aws_session_token=aws_session_token,
                aws_region=aws_region,
            ),
        ],
        "response": [
            map_response_from_bedrock()
        ],
    }


TextGeneration = typing.TypedDict('TextGeneration', {"text": str, "is_finished": str, "event_type": typing.Literal["text-generation"]})

StreamEnd = typing.TypedDict('StreamEnd', {"is_finished": str, "event_type": typing.Literal["stream-end"], "finish_reason": str,
                                           # "amazon-bedrock-invocationMetrics": {
                                           #     "inputTokenCount": int, "outputTokenCount": int, "invocationLatency": int,
                                           #     "firstByteLatency": int}
                                           })


class Streamer(SyncByteStream):
    # Adapter exposing an iterator of byte chunks as an httpx SyncByteStream.
    lines: typing.Iterator[bytes]

    def __init__(self, lines: typing.Iterator[bytes]):
        self.lines = lines

    def __iter__(self) -> typing.Iterator[bytes]:
        return self.lines


# Non-streaming endpoint name -> Cohere response type used to re-shape AWS output.
response_mapping: typing.Dict[str, typing.Any] = {
    "chat": NonStreamedChatResponse,
    "embed": EmbedResponse,
    "generate": Generation,
    "rerank": RerankResponse
}

# Streaming endpoint name -> Cohere streamed response type.
stream_response_mapping: typing.Dict[str, typing.Any] = {
    "chat": StreamedChatResponse,
    "generate": GenerateStreamedResponse,
}


def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator[bytes]:
    """Translate AWS event-stream frames into newline-delimited Cohere stream events.

    Each frame carries a JSON object with a base64-encoded "bytes" payload;
    frames without an "event_type" are skipped.
    """
    # NOTE(review): this regex matches the first non-nested {...} span per line;
    # it presumably relies on AWS frames keeping the JSON envelope flat.
    regex = r"{[^\}]*}"
    for _text in response.iter_lines():
        match = re.search(regex, _text)
        if match:
            obj = json.loads(match.group())
            if "bytes" in obj:
                base64_payload = base64.b64decode(obj["bytes"]).decode("utf-8")
                streamed_obj = json.loads(base64_payload)
                if "event_type" in streamed_obj:
                    response_type = stream_response_mapping[endpoint]
                    parsed = typing.cast(response_type,  # type: ignore
                                         construct_type(type_=response_type, object_=streamed_obj))
                    yield (json.dumps(parsed.dict()) + "\n").encode("utf-8")  # type: ignore


def map_token_counts(response: httpx.Response) -> ApiMeta:
    """Read Bedrock's token-count headers into ApiMeta; -1 means "not reported"."""
    input_tokens = int(response.headers.get("X-Amzn-Bedrock-Input-Token-Count", -1))
    output_tokens = int(response.headers.get("X-Amzn-Bedrock-Output-Token-Count", -1))
    return ApiMeta(
        tokens=ApiMetaTokens(input_tokens=input_tokens, output_tokens=output_tokens),
        billed_units=ApiMetaBilledUnits(input_tokens=input_tokens, output_tokens=output_tokens),
    )


def map_response_from_bedrock():
    """Return an httpx response hook that rewrites AWS responses into Cohere wire format."""
    def _hook(
        response: httpx.Response,
    ) -> None:
        # Bedrock streams use the amazon eventstream content type.
        stream = response.headers["content-type"] == "application/vnd.amazon.eventstream"
        # The request hook stashed the Cohere endpoint name in extensions.
        endpoint = response.request.extensions["endpoint"]
        output: typing.Iterator[bytes]
        if stream:
            output = stream_generator(httpx.Response(
                stream=response.stream,
                status_code=response.status_code,
            ), endpoint)
        else:
            # Re-shape the whole JSON body into the matching Cohere response
            # type and attach token-count metadata from the headers.
            response_type = response_mapping[endpoint]
            response_obj = json.loads(response.read())
            response_obj["meta"] = map_token_counts(response).dict()
            cast_obj: typing.Any = typing.cast(response_type,  # type: ignore
                                               construct_type(
                                                   type_=response_type,  # type: ignore
                                                   object_=response_obj))
            output = iter([json.dumps(cast_obj.dict()).encode("utf-8")])
        response.stream = Streamer(output)
        # reset response object to allow for re-reading
        if hasattr(response, "_content"):
            del response._content
        response.is_stream_consumed = False
        response.is_closed = False
return _hook def get_boto3_session( **kwargs: typing.Any, ): non_none_args = {k: v for k, v in kwargs.items() if v is not None} return lazy_boto3().Session(**non_none_args) def map_request_to_bedrock( service: str, aws_access_key: typing.Optional[str] = None, aws_secret_key: typing.Optional[str] = None, aws_session_token: typing.Optional[str] = None, aws_region: typing.Optional[str] = None, ) -> EventHook: session = get_boto3_session( region_name=aws_region, aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key, aws_session_token=aws_session_token, ) aws_region = session.region_name credentials = session.get_credentials() signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region) def _event_hook(request: httpx.Request) -> None: headers = request.headers.copy() del headers["connection"] api_version = request.url.path.split("/")[-2] endpoint = request.url.path.split("/")[-1] body = json.loads(request.read()) model = body["model"] url = get_url( platform=service, aws_region=aws_region, model=model, # type: ignore stream="stream" in body and body["stream"], ) request.url = URL(url) request.headers["host"] = request.url.host headers["host"] = request.url.host if endpoint == "rerank": body["api_version"] = get_api_version(version=api_version) if "stream" in body: del body["stream"] if "model" in body: del body["model"] new_body = json.dumps(body).encode("utf-8") request.stream = ByteStream(new_body) request._content = new_body headers["content-length"] = str(len(new_body)) aws_request = lazy_botocore().awsrequest.AWSRequest( method=request.method, url=url, headers=headers, data=request.read(), ) signer.add_auth(aws_request) request.headers = httpx.Headers(aws_request.prepare().headers) request.extensions["endpoint"] = endpoint return _event_hook def get_url( *, platform: str, aws_region: typing.Optional[str], model: str, stream: bool, ) -> str: if platform == "bedrock": endpoint = "invoke" if not stream else "invoke-with-response-stream" 
return f"https://{platform}-runtime.{aws_region}.amazonaws.com/model/{model}/{endpoint}" elif platform == "sagemaker": endpoint = "invocations" if not stream else "invocations-response-stream" return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}" return "" def get_api_version(*, version: str): int_version = { "v1": 1, "v2": 2, } return int_version.get(version, 1) ================================================ FILE: src/cohere/base_client.py ================================================ # This file was auto-generated by Fern from our API Definition. from __future__ import annotations import os import typing import httpx from .core.api_error import ApiError from .core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from .core.logging import LogConfig, Logger from .core.request_options import RequestOptions from .environment import ClientEnvironment from .raw_base_client import AsyncRawBaseCohere, RawBaseCohere from .types.chat_connector import ChatConnector from .types.chat_document import ChatDocument from .types.chat_request_citation_quality import ChatRequestCitationQuality from .types.chat_request_prompt_truncation import ChatRequestPromptTruncation from .types.chat_request_safety_mode import ChatRequestSafetyMode from .types.chat_stream_request_citation_quality import ChatStreamRequestCitationQuality from .types.chat_stream_request_prompt_truncation import ChatStreamRequestPromptTruncation from .types.chat_stream_request_safety_mode import ChatStreamRequestSafetyMode from .types.check_api_key_response import CheckApiKeyResponse from .types.classify_example import ClassifyExample from .types.classify_request_truncate import ClassifyRequestTruncate from .types.classify_response import ClassifyResponse from .types.detokenize_response import DetokenizeResponse from .types.embed_input_type import EmbedInputType from .types.embed_request_truncate import EmbedRequestTruncate from .types.embed_response import 
EmbedResponse from .types.embedding_type import EmbeddingType from .types.generate_request_return_likelihoods import GenerateRequestReturnLikelihoods from .types.generate_request_truncate import GenerateRequestTruncate from .types.generate_stream_request_return_likelihoods import GenerateStreamRequestReturnLikelihoods from .types.generate_stream_request_truncate import GenerateStreamRequestTruncate from .types.generate_streamed_response import GenerateStreamedResponse from .types.generation import Generation from .types.message import Message from .types.non_streamed_chat_response import NonStreamedChatResponse from .types.rerank_request_documents_item import RerankRequestDocumentsItem from .types.rerank_response import RerankResponse from .types.response_format import ResponseFormat from .types.streamed_chat_response import StreamedChatResponse from .types.summarize_request_extractiveness import SummarizeRequestExtractiveness from .types.summarize_request_format import SummarizeRequestFormat from .types.summarize_request_length import SummarizeRequestLength from .types.summarize_response import SummarizeResponse from .types.tokenize_response import TokenizeResponse from .types.tool import Tool from .types.tool_result import ToolResult if typing.TYPE_CHECKING: from .audio.client import AsyncAudioClient, AudioClient from .batches.client import AsyncBatchesClient, BatchesClient from .connectors.client import AsyncConnectorsClient, ConnectorsClient from .datasets.client import AsyncDatasetsClient, DatasetsClient from .embed_jobs.client import AsyncEmbedJobsClient, EmbedJobsClient from .finetuning.client import AsyncFinetuningClient, FinetuningClient from .models.client import AsyncModelsClient, ModelsClient from .v2.client import AsyncV2Client, V2Client # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) class BaseCohere: """ Use this class to access the different functions within the SDK. 
You can instantiate any number of clients with different configuration that will propagate to these functions. Parameters ---------- base_url : typing.Optional[str] The base url to use for requests from the client. environment : ClientEnvironment The environment to use for requests from the client. from .environment import ClientEnvironment Defaults to ClientEnvironment.PRODUCTION client_name : typing.Optional[str] token : typing.Optional[typing.Union[str, typing.Callable[[], str]]] headers : typing.Optional[typing.Dict[str, str]] Additional headers to send with every request. timeout : typing.Optional[float] The timeout to be used, in seconds, for requests. By default the timeout is 300 seconds, unless a custom httpx client is used, in which case this default is not enforced. follow_redirects : typing.Optional[bool] Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in. httpx_client : typing.Optional[httpx.Client] The httpx client to use for making requests, a preconfigured client is used by default, however this is useful should you want to pass in any custom httpx configuration. logging : typing.Optional[typing.Union[LogConfig, Logger]] Configure logging for the SDK. Accepts a LogConfig dict with 'level' (debug/info/warn/error), 'logger' (custom logger implementation), and 'silent' (boolean, defaults to True) fields. You can also pass a pre-configured Logger instance. 
Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) """ def __init__( self, *, base_url: typing.Optional[str] = None, environment: ClientEnvironment = ClientEnvironment.PRODUCTION, client_name: typing.Optional[str] = None, token: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = os.getenv("CO_API_KEY"), headers: typing.Optional[typing.Dict[str, str]] = None, timeout: typing.Optional[float] = None, follow_redirects: typing.Optional[bool] = True, httpx_client: typing.Optional[httpx.Client] = None, logging: typing.Optional[typing.Union[LogConfig, Logger]] = None, ): _defaulted_timeout = ( timeout if timeout is not None else 300 if httpx_client is None else httpx_client.timeout.read ) if token is None: raise ApiError(body="The client must be instantiated be either passing in token or setting CO_API_KEY") self._client_wrapper = SyncClientWrapper( base_url=_get_base_url(base_url=base_url, environment=environment), client_name=client_name, token=token, headers=headers, httpx_client=httpx_client if httpx_client is not None else httpx.Client(timeout=_defaulted_timeout, follow_redirects=follow_redirects) if follow_redirects is not None else httpx.Client(timeout=_defaulted_timeout), timeout=_defaulted_timeout, logging=logging, ) self._raw_client = RawBaseCohere(client_wrapper=self._client_wrapper) self._v2: typing.Optional[V2Client] = None self._batches: typing.Optional[BatchesClient] = None self._embed_jobs: typing.Optional[EmbedJobsClient] = None self._datasets: typing.Optional[DatasetsClient] = None self._connectors: typing.Optional[ConnectorsClient] = None self._models: typing.Optional[ModelsClient] = None self._finetuning: typing.Optional[FinetuningClient] = None self._audio: typing.Optional[AudioClient] = None @property def with_raw_response(self) -> RawBaseCohere: """ Retrieves a raw implementation of this client that returns raw responses. 
Returns ------- RawBaseCohere """ return self._raw_client def chat_stream( self, *, message: str, accepts: typing.Optional[typing.Literal["text/event-stream"]] = None, model: typing.Optional[str] = OMIT, preamble: typing.Optional[str] = OMIT, chat_history: typing.Optional[typing.Sequence[Message]] = OMIT, conversation_id: typing.Optional[str] = OMIT, prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT, connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT, search_queries_only: typing.Optional[bool] = OMIT, documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT, citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT, temperature: typing.Optional[float] = OMIT, max_tokens: typing.Optional[int] = OMIT, max_input_tokens: typing.Optional[int] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, presence_penalty: typing.Optional[float] = OMIT, raw_prompting: typing.Optional[bool] = OMIT, tools: typing.Optional[typing.Sequence[Tool]] = OMIT, tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT, force_single_step: typing.Optional[bool] = OMIT, response_format: typing.Optional[ResponseFormat] = OMIT, safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> typing.Iterator[StreamedChatResponse]: """ Generates a streamed text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api). Parameters ---------- message : str Text input for the model to respond to. 
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments accepts : typing.Optional[typing.Literal["text/event-stream"]] Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events. model : typing.Optional[str] The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Compatible Deployments: Cohere Platform, Private Deployments preamble : typing.Optional[str] When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments chat_history : typing.Optional[typing.Sequence[Message]] A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments conversation_id : typing.Optional[str] An alternative to `chat_history`. 
Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string. Compatible Deployments: Cohere Platform prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation] Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. Compatible Deployments: - AUTO: Cohere Platform Only - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments connectors : typing.Optional[typing.Sequence[ChatConnector]] Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one. When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG). Compatible Deployments: Cohere Platform search_queries_only : typing.Optional[bool] Defaults to `false`. When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. 
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments documents : typing.Optional[typing.Sequence[ChatDocument]] A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. Example: ``` [ { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." }, ] ``` Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments citation_quality : typing.Optional[ChatStreamRequestCitationQuality] Defaults to `"enabled"`. Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments temperature : typing.Optional[float] Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. 
Randomness can be further maximized by increasing the value of the `p` parameter. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_tokens : typing.Optional[int] The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_input_tokens : typing.Optional[int] The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer. Input will be truncated according to the `prompt_truncation` parameter. Compatible Deployments: Cohere Platform k : typing.Optional[int] Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments stop_sequences : typing.Optional[typing.Sequence[str]] A list of up to 5 strings that the model will use to stop generation. 
If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments frequency_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments raw_prompting : typing.Optional[bool] When enabled, the user's prompt will be sent to the model without any pre-processing. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tools : typing.Optional[typing.Sequence[Tool]] A list of available tools (functions) that the model may suggest invoking before producing a text response. When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tool_results : typing.Optional[typing.Sequence[ToolResult]] A list of results from invoking tools recommended by the model in the previous chat turn. 
Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. ``` tool_results = [ { "call": { "name": , "parameters": { : } }, "outputs": [{ : }] }, ... ] ``` **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments force_single_step : typing.Optional[bool] Forces the chat to be single step. Defaults to `false`. response_format : typing.Optional[ResponseFormat] safety_mode : typing.Optional[ChatStreamRequestSafetyMode] Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `NONE` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Yields ------ typing.Iterator[StreamedChatResponse] Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) response = client.chat_stream( model="command-a-03-2025", message="hello!", ) for chunk in response: yield chunk """ with self._raw_client.chat_stream( message=message, accepts=accepts, model=model, preamble=preamble, chat_history=chat_history, conversation_id=conversation_id, prompt_truncation=prompt_truncation, connectors=connectors, search_queries_only=search_queries_only, documents=documents, citation_quality=citation_quality, temperature=temperature, max_tokens=max_tokens, max_input_tokens=max_input_tokens, k=k, p=p, seed=seed, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, raw_prompting=raw_prompting, tools=tools, tool_results=tool_results, force_single_step=force_single_step, response_format=response_format, safety_mode=safety_mode, request_options=request_options, ) as r: yield from r.data def chat( self, *, message: str, accepts: typing.Optional[typing.Literal["text/event-stream"]] = None, model: typing.Optional[str] = OMIT, preamble: typing.Optional[str] = OMIT, chat_history: typing.Optional[typing.Sequence[Message]] = OMIT, conversation_id: typing.Optional[str] = OMIT, prompt_truncation: typing.Optional[ChatRequestPromptTruncation] = OMIT, connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT, search_queries_only: typing.Optional[bool] = OMIT, documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT, citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT, temperature: typing.Optional[float] = OMIT, max_tokens: typing.Optional[int] = OMIT, max_input_tokens: typing.Optional[int] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, 
presence_penalty: typing.Optional[float] = OMIT, raw_prompting: typing.Optional[bool] = OMIT, tools: typing.Optional[typing.Sequence[Tool]] = OMIT, tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT, force_single_step: typing.Optional[bool] = OMIT, response_format: typing.Optional[ResponseFormat] = OMIT, safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> NonStreamedChatResponse: """ Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api). Parameters ---------- message : str Text input for the model to respond to. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments accepts : typing.Optional[typing.Literal["text/event-stream"]] Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events. model : typing.Optional[str] The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Compatible Deployments: Cohere Platform, Private Deployments preamble : typing.Optional[str] When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. 
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments chat_history : typing.Optional[typing.Sequence[Message]] A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments conversation_id : typing.Optional[str] An alternative to `chat_history`. Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string. Compatible Deployments: Cohere Platform prompt_truncation : typing.Optional[ChatRequestPromptTruncation] Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. 
If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. Compatible Deployments: - AUTO: Cohere Platform Only - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments connectors : typing.Optional[typing.Sequence[ChatConnector]] Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one. When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG). Compatible Deployments: Cohere Platform search_queries_only : typing.Optional[bool] Defaults to `false`. When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments documents : typing.Optional[typing.Sequence[ChatDocument]] A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. Example: ``` [ { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." }, ] ``` Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. 
An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments citation_quality : typing.Optional[ChatRequestCitationQuality] Defaults to `"enabled"`. Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments temperature : typing.Optional[float] Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_tokens : typing.Optional[int] The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_input_tokens : typing.Optional[int] The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer. Input will be truncated according to the `prompt_truncation` parameter. Compatible Deployments: Cohere Platform k : typing.Optional[int] Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. 
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments stop_sequences : typing.Optional[typing.Sequence[str]] A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments frequency_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. 
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments raw_prompting : typing.Optional[bool] When enabled, the user's prompt will be sent to the model without any pre-processing. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tools : typing.Optional[typing.Sequence[Tool]] A list of available tools (functions) that the model may suggest invoking before producing a text response. When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tool_results : typing.Optional[typing.Sequence[ToolResult]] A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. ``` tool_results = [ { "call": { "name": , "parameters": { : } }, "outputs": [{ : }] }, ... ] ``` **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments force_single_step : typing.Optional[bool] Forces the chat to be single step. Defaults to `false`. response_format : typing.Optional[ResponseFormat] safety_mode : typing.Optional[ChatRequestSafetyMode] Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. 
Defaults to `CONTEXTUAL`. When `NONE` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- NonStreamedChatResponse Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.chat( model="command-a-03-2025", message="Tell me about LLMs", ) """ _response = self._raw_client.chat( message=message, accepts=accepts, model=model, preamble=preamble, chat_history=chat_history, conversation_id=conversation_id, prompt_truncation=prompt_truncation, connectors=connectors, search_queries_only=search_queries_only, documents=documents, citation_quality=citation_quality, temperature=temperature, max_tokens=max_tokens, max_input_tokens=max_input_tokens, k=k, p=p, seed=seed, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, raw_prompting=raw_prompting, tools=tools, tool_results=tool_results, force_single_step=force_single_step, response_format=response_format, safety_mode=safety_mode, request_options=request_options, ) return _response.data def generate_stream( self, *, prompt: str, model: typing.Optional[str] = OMIT, num_generations: typing.Optional[int] = OMIT, max_tokens: typing.Optional[int] = OMIT, truncate: typing.Optional[GenerateStreamRequestTruncate] = OMIT, temperature: typing.Optional[float] = OMIT, seed: 
typing.Optional[int] = OMIT, preset: typing.Optional[str] = OMIT, end_sequences: typing.Optional[typing.Sequence[str]] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, presence_penalty: typing.Optional[float] = OMIT, return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT, raw_prompting: typing.Optional[bool] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> typing.Iterator[GenerateStreamedResponse]: """ This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API. Generates realistic text conditioned on a given input. Parameters ---------- prompt : str The input text that serves as the starting point for generating the response. Note: The prompt will be pre-processed and modified before reaching the model. model : typing.Optional[str] The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID. num_generations : typing.Optional[int] The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`. max_tokens : typing.Optional[int] The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details. 
Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt. truncate : typing.Optional[GenerateStreamRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. temperature : typing.Optional[float] A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details. Defaults to `0.75`, min value of `0.0`, max value of `5.0`. seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments preset : typing.Optional[str] Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate). When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters. end_sequences : typing.Optional[typing.Sequence[str]] The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text. stop_sequences : typing.Optional[typing.Sequence[str]] The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text. 
k : typing.Optional[int] Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. frequency_penalty : typing.Optional[float] Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods] One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. If `GENERATION` is selected, the token likelihoods will only be provided for generated text. WARNING: `ALL` is deprecated, and will be removed in a future release. raw_prompting : typing.Optional[bool] When enabled, the user's prompt will be sent to the model without any pre-processing. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
        Yields
        ------
        typing.Iterator[GenerateStreamedResponse]

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        response = client.generate_stream(
            prompt="Please explain to me how LLMs work",
        )
        for chunk in response:
            yield chunk
        """
        # Delegate to the raw client; the context manager guarantees the
        # underlying HTTP stream is closed once the generator is exhausted.
        with self._raw_client.generate_stream(
            prompt=prompt,
            model=model,
            num_generations=num_generations,
            max_tokens=max_tokens,
            truncate=truncate,
            temperature=temperature,
            seed=seed,
            preset=preset,
            end_sequences=end_sequences,
            stop_sequences=stop_sequences,
            k=k,
            p=p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            return_likelihoods=return_likelihoods,
            raw_prompting=raw_prompting,
            request_options=request_options,
        ) as r:
            yield from r.data

    def generate(
        self,
        *,
        prompt: str,
        model: typing.Optional[str] = OMIT,
        num_generations: typing.Optional[int] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        truncate: typing.Optional[GenerateRequestTruncate] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        preset: typing.Optional[str] = OMIT,
        end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> Generation:
        """
        This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.

        Generates realistic text conditioned on a given input.

        Parameters
        ----------
        prompt : str
            The input text that serves as the starting point for generating the response.
            Note: The prompt will be pre-processed and modified before reaching the model.

        model : typing.Optional[str]
            The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
            Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

        num_generations : typing.Optional[int]
            The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

            This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.

            Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

        truncate : typing.Optional[GenerateRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        temperature : typing.Optional[float]
            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
            Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        preset : typing.Optional[str]
            Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
            When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

        end_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        frequency_penalty : typing.Optional[float]
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods]
            One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.

            If `GENERATION` is selected, the token likelihoods will only be provided for generated text.

            WARNING: `ALL` is deprecated, and will be removed in a future release.

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        Generation


        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.generate(
            prompt="Please explain to me how LLMs work",
        )
        """
        # Thin wrapper: forward every parameter to the raw client and
        # unwrap the parsed response body.
        _response = self._raw_client.generate(
            prompt=prompt,
            model=model,
            num_generations=num_generations,
            max_tokens=max_tokens,
            truncate=truncate,
            temperature=temperature,
            seed=seed,
            preset=preset,
            end_sequences=end_sequences,
            stop_sequences=stop_sequences,
            k=k,
            p=p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            return_likelihoods=return_likelihoods,
            raw_prompting=raw_prompting,
            request_options=request_options,
        )
        return _response.data

    def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> EmbedResponse:
        """
        This endpoint returns text
        and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents.

        Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

        If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

        Parameters
        ----------
        texts : typing.Optional[typing.Sequence[str]]
            An array of strings for the model to embed. Maximum number of texts per call is `96`.

        images : typing.Optional[typing.Sequence[str]]
            An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

            The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.

            Images are only supported with Embed v3.0 and newer models.

        model : typing.Optional[str]
            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

        input_type : typing.Optional[EmbedInputType]

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.

        truncate : typing.Optional[EmbedRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.embed(
            texts=["hello", "goodbye"],
            model="embed-v4.0",
            input_type="classification",
        )
        """
        # Thin wrapper: forward every parameter to the raw client and
        # unwrap the parsed response body.
        _response = self._raw_client.embed(
            texts=texts,
            images=images,
            model=model,
            input_type=input_type,
            embedding_types=embedding_types,
            truncate=truncate,
            request_options=request_options,
        )
        return _response.data

    def rerank(
        self,
        *,
        query: str,
        documents: typing.Sequence[RerankRequestDocumentsItem],
        model: typing.Optional[str] = OMIT,
        top_n: typing.Optional[int] = OMIT,
        rank_fields: typing.Optional[typing.Sequence[str]] = OMIT,
        return_documents: typing.Optional[bool] = OMIT,
        max_chunks_per_doc: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> RerankResponse:
        """
        This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.

        Parameters
        ----------
        query : str
            The search query

        documents : typing.Sequence[RerankRequestDocumentsItem]
            A list of document objects or strings to rerank.
            If a document is provided the text fields is required and all other fields will be preserved in the response.

            The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000.
            We recommend a maximum of 1,000 documents for optimal endpoint performance.

        model : typing.Optional[str]
            The identifier of the model to use, eg `rerank-v3.5`.

        top_n : typing.Optional[int]
            The number of most relevant documents or indices to return, defaults to the length of the documents

        rank_fields : typing.Optional[typing.Sequence[str]]
            If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially. If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking.

        return_documents : typing.Optional[bool]
            - If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request.
            - If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.

        max_chunks_per_doc : typing.Optional[int]
            The maximum number of chunks to produce internally from a document

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        RerankResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.rerank(
            documents=[
                {
                    "text": "Carson City is the capital city of the American state of Nevada."
                },
                {
                    "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
                },
                {
                    "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
                },
                {
                    "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
                },
                {
                    "text": "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
                },
            ],
            query="What is the capital of the United States?",
            top_n=3,
            model="rerank-v4.0-pro",
        )
        """
        # Thin wrapper: forward every parameter to the raw client and
        # unwrap the parsed response body.
        _response = self._raw_client.rerank(
            query=query,
            documents=documents,
            model=model,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            request_options=request_options,
        )
        return _response.data

    def classify(
        self,
        *,
        inputs: typing.Sequence[str],
        examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT,
        model: typing.Optional[str] = OMIT,
        preset: typing.Optional[str] = OMIT,
        truncate: typing.Optional[ClassifyRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ClassifyResponse:
        """
        This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference.
        Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.

        Parameters
        ----------
        inputs : typing.Sequence[str]
            A list of up to 96 texts to be classified. Each one must be a non-empty string.
            There is, however, no consistent, universal limit to the length a particular input can be. We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models).
            Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts.

        examples : typing.Optional[typing.Sequence[ClassifyExample]]
            An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`.
            Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.

        model : typing.Optional[str]
            ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model

        preset : typing.Optional[str]
            The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground).
            If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters.

        truncate : typing.Optional[ClassifyRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ClassifyResponse
            OK

        Examples
        --------
        from cohere import ClassifyExample, Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.classify(
            examples=[
                ClassifyExample(
                    text="Dermatologists don't like her!",
                    label="Spam",
                ),
                ClassifyExample(
                    text="'Hello, open to this?'",
                    label="Spam",
                ),
                ClassifyExample(
                    text="I need help please wire me $1000 right now",
                    label="Spam",
                ),
                ClassifyExample(
                    text="Nice to know you ;)",
                    label="Spam",
                ),
                ClassifyExample(
                    text="Please help me?",
                    label="Spam",
                ),
                ClassifyExample(
                    text="Your parcel will be delivered today",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="Review changes to our Terms and Conditions",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="Weekly sync notes",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="'Re: Follow up from today's meeting'",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="Pre-read for tomorrow",
                    label="Not spam",
                ),
            ],
            inputs=["Confirm your email address", "hey i need u to send some $"],
            model="YOUR-FINE-TUNED-MODEL-ID",
        )
        """
        # Thin wrapper: forward every parameter to the raw client and
        # unwrap the parsed response body.
        _response = self._raw_client.classify(
            inputs=inputs,
            examples=examples,
            model=model,
            preset=preset,
            truncate=truncate,
            request_options=request_options,
        )
        return _response.data

    def summarize(
        self,
        *,
        text: str,
        length: typing.Optional[SummarizeRequestLength] = OMIT,
        format: typing.Optional[SummarizeRequestFormat] = OMIT,
        model: typing.Optional[str] = OMIT,
        extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        additional_command: typing.Optional[str] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> SummarizeResponse:
        """
        This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.

        Generates a summary in English for a given text.

        Parameters
        ----------
        text : str
            The text to generate a summary for. Can be up to 100,000 characters long.
            Currently the only supported language is English.

        length : typing.Optional[SummarizeRequestLength]
            One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text.

        format : typing.Optional[SummarizeRequestFormat]
            One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text.

        model : typing.Optional[str]
            The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better.

        extractiveness : typing.Optional[SummarizeRequestExtractiveness]
            One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text.

        temperature : typing.Optional[float]
            Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.

        additional_command : typing.Optional[str]
            A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda"

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        SummarizeResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.summarize(
            text='Ice cream is a sweetened frozen food typically eaten as a snack or dessert. It may be made from milk or cream and is flavoured with a sweetener, either sugar or an alternative, and a spice, such as cocoa or vanilla, or with fruit such as strawberries or peaches. It can also be made by whisking a flavored cream base and liquid nitrogen together. Food coloring is sometimes added, in addition to stabilizers. The mixture is cooled below the freezing point of water and stirred to incorporate air spaces and to prevent detectable ice crystals from forming. The result is a smooth, semi-solid foam that is solid at very low temperatures (below 2 °C or 35 °F). It becomes more malleable as its temperature increases.\n\nThe meaning of the name "ice cream" varies from one country to another. In some countries, such as the United States, "ice cream" applies only to a specific variety, and most governments regulate the commercial use of the various terms according to the relative quantities of the main ingredients, notably the amount of cream. Products that do not meet the criteria to be called ice cream are sometimes labelled "frozen dairy dessert" instead. In other countries, such as Italy and Argentina, one word is used for all variants. Analogues made from dairy alternatives, such as goat\'s or sheep\'s milk, or milk substitutes (e.g., soy, cashew, coconut, almond milk or tofu), are available for those who are lactose intolerant, allergic to dairy protein or vegan.',
        )
        """
        # Thin wrapper: forward every parameter to the raw client and
        # unwrap the parsed response body.
        _response = self._raw_client.summarize(
            text=text,
            length=length,
            format=format,
            model=model,
            extractiveness=extractiveness,
            temperature=temperature,
            additional_command=additional_command,
            request_options=request_options,
        )
        return _response.data

    def tokenize(
        self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None
    ) -> TokenizeResponse:
        """
        This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page.

        Parameters
        ----------
        text : str
            The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters.

        model : str
            The input will be tokenized by the tokenizer that is used by this model.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        TokenizeResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.tokenize(
            text="tokenize me! :D",
            model="command",
        )
        """
        # Thin wrapper: forward to the raw client and unwrap the body.
        _response = self._raw_client.tokenize(text=text, model=model, request_options=request_options)
        return _response.data

    def detokenize(
        self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None
    ) -> DetokenizeResponse:
        """
        This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page.

        Parameters
        ----------
        tokens : typing.Sequence[int]
            The list of tokens to be detokenized.

        model : str
            An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DetokenizeResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.detokenize(
            tokens=[10002, 2261, 2012, 8, 2792, 43],
            model="command",
        )
        """
        # Thin wrapper: forward to the raw client and unwrap the body.
        _response = self._raw_client.detokenize(tokens=tokens, model=model, request_options=request_options)
        return _response.data

    def check_api_key(self, *, request_options: typing.Optional[RequestOptions] = None) -> CheckApiKeyResponse:
        """
        Checks that the api key in the Authorization header is valid and active

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CheckApiKeyResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.check_api_key()
        """
        # Thin wrapper: forward to the raw client and unwrap the body.
        _response = self._raw_client.check_api_key(request_options=request_options)
        return _response.data

    # The sub-client accessors below lazily import and construct their client
    # on first access, then cache it on the instance, so that constructing the
    # top-level client does not pay the import cost of every sub-module.

    @property
    def v2(self):
        if self._v2 is None:
            from .v2.client import V2Client  # noqa: E402

            self._v2 = V2Client(client_wrapper=self._client_wrapper)
        return self._v2

    @property
    def batches(self):
        if self._batches is None:
            from .batches.client import BatchesClient  # noqa: E402

            self._batches = BatchesClient(client_wrapper=self._client_wrapper)
        return self._batches

    @property
    def embed_jobs(self):
        if self._embed_jobs is None:
            from .embed_jobs.client import EmbedJobsClient  # noqa: E402

            self._embed_jobs = EmbedJobsClient(client_wrapper=self._client_wrapper)
        return self._embed_jobs

    @property
    def datasets(self):
        if self._datasets is None:
            from .datasets.client import DatasetsClient  # noqa: E402

            self._datasets = DatasetsClient(client_wrapper=self._client_wrapper)
        return self._datasets

    @property
    def connectors(self):
        if self._connectors is None:
            from .connectors.client import ConnectorsClient  # noqa: E402

            self._connectors = ConnectorsClient(client_wrapper=self._client_wrapper)
        return self._connectors

    @property
    def models(self):
        if self._models is None:
            from .models.client import ModelsClient  # noqa: E402

            self._models = ModelsClient(client_wrapper=self._client_wrapper)
        return self._models

    @property
    def finetuning(self):
        if self._finetuning is None:
            from .finetuning.client import FinetuningClient  # noqa: E402

            self._finetuning = FinetuningClient(client_wrapper=self._client_wrapper)
        return self._finetuning

    @property
    def audio(self):
        if self._audio is None:
            from .audio.client import AudioClient  # noqa: E402

            self._audio = AudioClient(client_wrapper=self._client_wrapper)
        return self._audio


def _make_default_async_client(
    timeout: typing.Optional[float],
    follow_redirects: typing.Optional[bool],
) -> httpx.AsyncClient:
    """
    Build the default async HTTP client.

    Prefers the aiohttp-backed httpx transport when the optional
    `httpx_aiohttp` package is installed, falling back to a plain
    `httpx.AsyncClient` otherwise. `follow_redirects` is only passed through
    when explicitly set, so each client's own default applies when it is None.
    """
    try:
        import httpx_aiohttp  # type: ignore[import-not-found]
    except ImportError:
        # Optional dependency not installed; fall back to httpx below.
        pass
    else:
        if follow_redirects is not None:
            return httpx_aiohttp.HttpxAiohttpClient(timeout=timeout, follow_redirects=follow_redirects)
        return httpx_aiohttp.HttpxAiohttpClient(timeout=timeout)
    if follow_redirects is not None:
        return httpx.AsyncClient(timeout=timeout, follow_redirects=follow_redirects)
    return httpx.AsyncClient(timeout=timeout)


class AsyncBaseCohere:
    """
    Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propagate to these functions.

    Parameters
    ----------
    base_url : typing.Optional[str]
        The base url to use for requests from the client.

    environment : ClientEnvironment
        The environment to use for requests from the client. from .environment import ClientEnvironment

        Defaults to ClientEnvironment.PRODUCTION

    client_name : typing.Optional[str]

    token : typing.Optional[typing.Union[str, typing.Callable[[], str]]]

    headers : typing.Optional[typing.Dict[str, str]]
        Additional headers to send with every request.

    async_token : typing.Optional[typing.Callable[[], typing.Awaitable[str]]]
        An async callable that returns a bearer token.
Use this when token acquisition involves async I/O (e.g., refreshing tokens via an async HTTP client). When provided, this is used instead of the synchronous token for async requests. timeout : typing.Optional[float] The timeout to be used, in seconds, for requests. By default the timeout is 300 seconds, unless a custom httpx client is used, in which case this default is not enforced. follow_redirects : typing.Optional[bool] Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in. httpx_client : typing.Optional[httpx.AsyncClient] The httpx client to use for making requests, a preconfigured client is used by default, however this is useful should you want to pass in any custom httpx configuration. logging : typing.Optional[typing.Union[LogConfig, Logger]] Configure logging for the SDK. Accepts a LogConfig dict with 'level' (debug/info/warn/error), 'logger' (custom logger implementation), and 'silent' (boolean, defaults to True) fields. You can also pass a pre-configured Logger instance. 
Examples -------- from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) """ def __init__( self, *, base_url: typing.Optional[str] = None, environment: ClientEnvironment = ClientEnvironment.PRODUCTION, client_name: typing.Optional[str] = None, token: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = os.getenv("CO_API_KEY"), headers: typing.Optional[typing.Dict[str, str]] = None, async_token: typing.Optional[typing.Callable[[], typing.Awaitable[str]]] = None, timeout: typing.Optional[float] = None, follow_redirects: typing.Optional[bool] = True, httpx_client: typing.Optional[httpx.AsyncClient] = None, logging: typing.Optional[typing.Union[LogConfig, Logger]] = None, ): _defaulted_timeout = ( timeout if timeout is not None else 300 if httpx_client is None else httpx_client.timeout.read ) if token is None: raise ApiError(body="The client must be instantiated be either passing in token or setting CO_API_KEY") self._client_wrapper = AsyncClientWrapper( base_url=_get_base_url(base_url=base_url, environment=environment), client_name=client_name, token=token, headers=headers, async_token=async_token, httpx_client=httpx_client if httpx_client is not None else _make_default_async_client(timeout=_defaulted_timeout, follow_redirects=follow_redirects), timeout=_defaulted_timeout, logging=logging, ) self._raw_client = AsyncRawBaseCohere(client_wrapper=self._client_wrapper) self._v2: typing.Optional[AsyncV2Client] = None self._batches: typing.Optional[AsyncBatchesClient] = None self._embed_jobs: typing.Optional[AsyncEmbedJobsClient] = None self._datasets: typing.Optional[AsyncDatasetsClient] = None self._connectors: typing.Optional[AsyncConnectorsClient] = None self._models: typing.Optional[AsyncModelsClient] = None self._finetuning: typing.Optional[AsyncFinetuningClient] = None self._audio: typing.Optional[AsyncAudioClient] = None @property def with_raw_response(self) -> AsyncRawBaseCohere: """ Retrieves a raw 
implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawBaseCohere
        """
        return self._raw_client

    async def chat_stream(
        self,
        *,
        message: str,
        accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
        model: typing.Optional[str] = OMIT,
        preamble: typing.Optional[str] = OMIT,
        chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
        conversation_id: typing.Optional[str] = OMIT,
        prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT,
        connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
        search_queries_only: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
        citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        max_input_tokens: typing.Optional[int] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
        tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
        force_single_step: typing.Optional[bool] = OMIT,
        response_format: typing.Optional[ResponseFormat] = OMIT,
        safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.AsyncIterator[StreamedChatResponse]:
        """
        Generates a streamed text response to a user message.

        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

        Parameters
        ----------
        message : str
            Text input for the model to respond to.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        accepts : typing.Optional[typing.Literal["text/event-stream"]]
            Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.

        model : typing.Optional[str]
            The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.

            Compatible Deployments: Cohere Platform, Private Deployments

        preamble : typing.Optional[str]
            When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.

            The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        chat_history : typing.Optional[typing.Sequence[Message]]
            A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.

            Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.

            The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        conversation_id : typing.Optional[str]
            An alternative to `chat_history`.

            Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.

            Compatible Deployments: Cohere Platform

        prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation]
            Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.

            Dictates how the prompt will be constructed.

            With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.

            With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.

            With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.

            Compatible Deployments:
             - AUTO: Cohere Platform Only
             - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments

        connectors : typing.Optional[typing.Sequence[ChatConnector]]
            Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.

            When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).

            Compatible Deployments: Cohere Platform

        search_queries_only : typing.Optional[bool]
            Defaults to `false`.

            When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        documents : typing.Optional[typing.Sequence[ChatDocument]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.

            Example:
            ```
            [
              { "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
              { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
            ]
            ```

            Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.

            Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.

            An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.

            An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.

            See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        citation_quality : typing.Optional[ChatStreamRequestCitationQuality]
            Defaults to `"enabled"`.

            Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_input_tokens : typing.Optional[int]
            The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.

            Input will be truncated according to the `prompt_truncation` parameter.

            Compatible Deployments: Cohere Platform

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tools : typing.Optional[typing.Sequence[Tool]]
            A list of available tools (functions) that the model may suggest invoking before producing a text response.

            When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tool_results : typing.Optional[typing.Sequence[ToolResult]]
            A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
            Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.

            **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
            ```
            tool_results = [
              {
                "call": {
                  "name": <tool name>,
                  "parameters": {
                    <param name>: <param value>
                  }
                },
                "outputs": [{
                  <key>: <value>
                }]
              },
            ...
            ]
            ```
            **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        force_single_step : typing.Optional[bool]
            Forces the chat to be single step. Defaults to `false`.

        response_format : typing.Optional[ResponseFormat]

        safety_mode : typing.Optional[ChatStreamRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `NONE` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.AsyncIterator[StreamedChatResponse]


        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            response = await client.chat_stream(
                model="command-a-03-2025",
                message="hello!",
            )
            async for chunk in response:
                print(chunk)


        asyncio.run(main())
        """
        # Delegate to the raw client; re-yield each streamed event so callers
        # get a plain async iterator without managing the context manager.
        async with self._raw_client.chat_stream(
            message=message,
            accepts=accepts,
            model=model,
            preamble=preamble,
            chat_history=chat_history,
            conversation_id=conversation_id,
            prompt_truncation=prompt_truncation,
            connectors=connectors,
            search_queries_only=search_queries_only,
            documents=documents,
            citation_quality=citation_quality,
            temperature=temperature,
            max_tokens=max_tokens,
            max_input_tokens=max_input_tokens,
            k=k,
            p=p,
            seed=seed,
            stop_sequences=stop_sequences,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            raw_prompting=raw_prompting,
            tools=tools,
            tool_results=tool_results,
            force_single_step=force_single_step,
            response_format=response_format,
            safety_mode=safety_mode,
            request_options=request_options,
        ) as r:
            async for _chunk in r.data:
                yield _chunk

    async def chat(
        self,
        *,
        message: str,
        accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
        model: typing.Optional[str] = OMIT,
        preamble: typing.Optional[str] = OMIT,
        chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
        conversation_id: typing.Optional[str] = OMIT,
        prompt_truncation: typing.Optional[ChatRequestPromptTruncation] = OMIT,
        connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
        search_queries_only: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
        citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        max_input_tokens: typing.Optional[int] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] =
OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
        tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
        force_single_step: typing.Optional[bool] = OMIT,
        response_format: typing.Optional[ResponseFormat] = OMIT,
        safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> NonStreamedChatResponse:
        """
        Generates a text response to a user message.
        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

        Parameters
        ----------
        message : str
            Text input for the model to respond to.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        accepts : typing.Optional[typing.Literal["text/event-stream"]]
            Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.

        model : typing.Optional[str]
            The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.

            Compatible Deployments: Cohere Platform, Private Deployments

        preamble : typing.Optional[str]
            When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.

            The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        chat_history : typing.Optional[typing.Sequence[Message]]
            A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.

            Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.

            The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        conversation_id : typing.Optional[str]
            An alternative to `chat_history`.

            Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.

            Compatible Deployments: Cohere Platform

        prompt_truncation : typing.Optional[ChatRequestPromptTruncation]
            Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.

            Dictates how the prompt will be constructed.

            With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.

            With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.

            With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.

            Compatible Deployments:
             - AUTO: Cohere Platform Only
             - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments

        connectors : typing.Optional[typing.Sequence[ChatConnector]]
            Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.

            When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).

            Compatible Deployments: Cohere Platform

        search_queries_only : typing.Optional[bool]
            Defaults to `false`.

            When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        documents : typing.Optional[typing.Sequence[ChatDocument]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.

            Example:
            ```
            [
              { "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
              { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
            ]
            ```

            Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.

            Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.

            An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.

            An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.

            See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        citation_quality : typing.Optional[ChatRequestCitationQuality]
            Defaults to `"enabled"`.

            Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_input_tokens : typing.Optional[int]
            The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.

            Input will be truncated according to the `prompt_truncation` parameter.

            Compatible Deployments: Cohere Platform

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tools : typing.Optional[typing.Sequence[Tool]]
            A list of available tools (functions) that the model may suggest invoking before producing a text response.

            When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tool_results : typing.Optional[typing.Sequence[ToolResult]]
            A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
            Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.

            **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
            ```
            tool_results = [
              {
                "call": {
                  "name": <tool name>,
                  "parameters": {
                    <param name>: <param value>
                  }
                },
                "outputs": [{
                  <key>: <value>
                }]
              },
            ...
            ]
            ```
            **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        force_single_step : typing.Optional[bool]
            Forces the chat to be single step. Defaults to `false`.

        response_format : typing.Optional[ResponseFormat]

        safety_mode : typing.Optional[ChatRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `NONE` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        NonStreamedChatResponse


        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.chat(
                model="command-a-03-2025",
                message="Tell me about LLMs",
            )


        asyncio.run(main())
        """
        # Delegate to the raw client and unwrap the parsed response body.
        _response = await self._raw_client.chat(
            message=message,
            accepts=accepts,
            model=model,
            preamble=preamble,
            chat_history=chat_history,
            conversation_id=conversation_id,
            prompt_truncation=prompt_truncation,
            connectors=connectors,
            search_queries_only=search_queries_only,
            documents=documents,
            citation_quality=citation_quality,
            temperature=temperature,
            max_tokens=max_tokens,
            max_input_tokens=max_input_tokens,
            k=k,
            p=p,
            seed=seed,
            stop_sequences=stop_sequences,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            raw_prompting=raw_prompting,
            tools=tools,
            tool_results=tool_results,
            force_single_step=force_single_step,
            response_format=response_format,
            safety_mode=safety_mode,
            request_options=request_options,
        )
        return _response.data

    async def generate_stream(
        self,
        *,
        prompt: str,
        model: typing.Optional[str] = OMIT,
        num_generations: typing.Optional[int] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        truncate:
typing.Optional[GenerateStreamRequestTruncate] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        preset: typing.Optional[str] = OMIT,
        end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.AsyncIterator[GenerateStreamedResponse]:
        """
        This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.
        Generates realistic text conditioned on a given input.

        Parameters
        ----------
        prompt : str
            The input text that serves as the starting point for generating the response.
            Note: The prompt will be pre-processed and modified before reaching the model.

        model : typing.Optional[str]
            The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
            Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

        num_generations : typing.Optional[int]
            The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

            This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.

            Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

        truncate : typing.Optional[GenerateStreamRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        temperature : typing.Optional[float]
            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
            Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        preset : typing.Optional[str]
            Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
            When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

        end_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        frequency_penalty : typing.Optional[float]
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods]
            One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.

            If `GENERATION` is selected, the token likelihoods will only be provided for generated text.

            WARNING: `ALL` is deprecated, and will be removed in a future release.

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.AsyncIterator[GenerateStreamedResponse]


        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            response = await client.generate_stream(
                prompt="Please explain to me how LLMs work",
            )
            async for chunk in response:
                print(chunk)


        asyncio.run(main())
        """
        # Delegate to the raw client; re-yield each streamed event so callers
        # get a plain async iterator without managing the context manager.
        async with self._raw_client.generate_stream(
            prompt=prompt,
            model=model,
            num_generations=num_generations,
            max_tokens=max_tokens,
            truncate=truncate,
            temperature=temperature,
            seed=seed,
            preset=preset,
            end_sequences=end_sequences,
            stop_sequences=stop_sequences,
            k=k,
            p=p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            return_likelihoods=return_likelihoods,
            raw_prompting=raw_prompting,
            request_options=request_options,
        ) as r:
            async for _chunk in r.data:
                yield _chunk

    async def generate(
        self,
        *,
        prompt: str,
        model: typing.Optional[str] = OMIT,
        num_generations: typing.Optional[int] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        truncate: typing.Optional[GenerateRequestTruncate] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        preset: typing.Optional[str] = OMIT,
        end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> Generation:
        """
        This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.
        Generates realistic text conditioned on a given input.
Parameters ---------- prompt : str The input text that serves as the starting point for generating the response. Note: The prompt will be pre-processed and modified before reaching the model. model : typing.Optional[str] The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID. num_generations : typing.Optional[int] The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`. max_tokens : typing.Optional[int] The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details. Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt. truncate : typing.Optional[GenerateRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. temperature : typing.Optional[float] A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details. Defaults to `0.75`, min value of `0.0`, max value of `5.0`. 
seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments preset : typing.Optional[str] Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate). When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters. end_sequences : typing.Optional[typing.Sequence[str]] The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text. stop_sequences : typing.Optional[typing.Sequence[str]] The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text. k : typing.Optional[int] Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. frequency_penalty : typing.Optional[float] Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. 
Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods] One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. If `GENERATION` is selected, the token likelihoods will only be provided for generated text. WARNING: `ALL` is deprecated, and will be removed in a future release. raw_prompting : typing.Optional[bool] When enabled, the user's prompt will be sent to the model without any pre-processing. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- Generation Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.generate( prompt="Please explain to me how LLMs work", ) asyncio.run(main()) """ _response = await self._raw_client.generate( prompt=prompt, model=model, num_generations=num_generations, max_tokens=max_tokens, truncate=truncate, temperature=temperature, seed=seed, preset=preset, end_sequences=end_sequences, stop_sequences=stop_sequences, k=k, p=p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, return_likelihoods=return_likelihoods, raw_prompting=raw_prompting, request_options=request_options, ) return _response.data async def embed( self, *, texts: typing.Optional[typing.Sequence[str]] = OMIT, images: typing.Optional[typing.Sequence[str]] = OMIT, model: typing.Optional[str] = OMIT, input_type: typing.Optional[EmbedInputType] = OMIT, embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT, truncate: typing.Optional[EmbedRequestTruncate] = OMIT, request_options: 
typing.Optional[RequestOptions] = None, ) -> EmbedResponse: """ This endpoint returns text and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents. Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page. If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search). Parameters ---------- texts : typing.Optional[typing.Sequence[str]] An array of strings for the model to embed. Maximum number of texts per call is `96`. images : typing.Optional[typing.Sequence[str]] An array of image data URIs for the model to embed. Maximum number of images per call is `1`. The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB. Images are only supported with Embed v3.0 and newer models. model : typing.Optional[str] ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed). input_type : typing.Optional[EmbedInputType] embedding_types : typing.Optional[typing.Sequence[EmbeddingType]] Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models. 
* `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models. truncate : typing.Optional[EmbedRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- EmbedResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.embed( texts=["hello", "goodbye"], model="embed-v4.0", input_type="classification", ) asyncio.run(main()) """ _response = await self._raw_client.embed( texts=texts, images=images, model=model, input_type=input_type, embedding_types=embedding_types, truncate=truncate, request_options=request_options, ) return _response.data async def rerank( self, *, query: str, documents: typing.Sequence[RerankRequestDocumentsItem], model: typing.Optional[str] = OMIT, top_n: typing.Optional[int] = OMIT, rank_fields: typing.Optional[typing.Sequence[str]] = OMIT, return_documents: typing.Optional[bool] = OMIT, max_chunks_per_doc: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> RerankResponse: """ This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score. Parameters ---------- query : str The search query documents : typing.Sequence[RerankRequestDocumentsItem] A list of document objects or strings to rerank. 
If a document is provided the text fields is required and all other fields will be preserved in the response. The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000. We recommend a maximum of 1,000 documents for optimal endpoint performance. model : typing.Optional[str] The identifier of the model to use, eg `rerank-v3.5`. top_n : typing.Optional[int] The number of most relevant documents or indices to return, defaults to the length of the documents rank_fields : typing.Optional[typing.Sequence[str]] If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially. If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking. return_documents : typing.Optional[bool] - If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request. - If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request. max_chunks_per_doc : typing.Optional[int] The maximum number of chunks to produce internally from a document request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- RerankResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.rerank( documents=[ { "text": "Carson City is the capital city of the American state of Nevada." 
}, { "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan." }, { "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages." }, { "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district." }, { "text": "Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." }, ], query="What is the capital of the United States?", top_n=3, model="rerank-v4.0-pro", ) asyncio.run(main()) """ _response = await self._raw_client.rerank( query=query, documents=documents, model=model, top_n=top_n, rank_fields=rank_fields, return_documents=return_documents, max_chunks_per_doc=max_chunks_per_doc, request_options=request_options, ) return _response.data async def classify( self, *, inputs: typing.Sequence[str], examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT, model: typing.Optional[str] = OMIT, preset: typing.Optional[str] = OMIT, truncate: typing.Optional[ClassifyRequestTruncate] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> ClassifyResponse: """ This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference. Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. Parameters ---------- inputs : typing.Sequence[str] A list of up to 96 texts to be classified. Each one must be a non-empty string. There is, however, no consistent, universal limit to the length a particular input can be. 
We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models). Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts. examples : typing.Optional[typing.Sequence[ClassifyExample]] An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`. Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. model : typing.Optional[str] ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model preset : typing.Optional[str] The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground). If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters. truncate : typing.Optional[ClassifyRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. 
request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ClassifyResponse OK Examples -------- import asyncio from cohere import AsyncClient, ClassifyExample client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.classify( examples=[ ClassifyExample( text="Dermatologists don't like her!", label="Spam", ), ClassifyExample( text="'Hello, open to this?'", label="Spam", ), ClassifyExample( text="I need help please wire me $1000 right now", label="Spam", ), ClassifyExample( text="Nice to know you ;)", label="Spam", ), ClassifyExample( text="Please help me?", label="Spam", ), ClassifyExample( text="Your parcel will be delivered today", label="Not spam", ), ClassifyExample( text="Review changes to our Terms and Conditions", label="Not spam", ), ClassifyExample( text="Weekly sync notes", label="Not spam", ), ClassifyExample( text="'Re: Follow up from today's meeting'", label="Not spam", ), ClassifyExample( text="Pre-read for tomorrow", label="Not spam", ), ], inputs=["Confirm your email address", "hey i need u to send some $"], model="YOUR-FINE-TUNED-MODEL-ID", ) asyncio.run(main()) """ _response = await self._raw_client.classify( inputs=inputs, examples=examples, model=model, preset=preset, truncate=truncate, request_options=request_options, ) return _response.data async def summarize( self, *, text: str, length: typing.Optional[SummarizeRequestLength] = OMIT, format: typing.Optional[SummarizeRequestFormat] = OMIT, model: typing.Optional[str] = OMIT, extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT, temperature: typing.Optional[float] = OMIT, additional_command: typing.Optional[str] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> SummarizeResponse: """ This API is marked as "Legacy" and is no longer maintained. 
Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. Generates a summary in English for a given text. Parameters ---------- text : str The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English. length : typing.Optional[SummarizeRequestLength] One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text. format : typing.Optional[SummarizeRequestFormat] One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text. model : typing.Optional[str] The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. extractiveness : typing.Optional[SummarizeRequestExtractiveness] One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text. temperature : typing.Optional[float] Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1. additional_command : typing.Optional[str] A free-form instruction for modifying how the summaries get generated. 
Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda" request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- SummarizeResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.summarize( text='Ice cream is a sweetened frozen food typically eaten as a snack or dessert. It may be made from milk or cream and is flavoured with a sweetener, either sugar or an alternative, and a spice, such as cocoa or vanilla, or with fruit such as strawberries or peaches. It can also be made by whisking a flavored cream base and liquid nitrogen together. Food coloring is sometimes added, in addition to stabilizers. The mixture is cooled below the freezing point of water and stirred to incorporate air spaces and to prevent detectable ice crystals from forming. The result is a smooth, semi-solid foam that is solid at very low temperatures (below 2 °C or 35 °F). It becomes more malleable as its temperature increases.\n\nThe meaning of the name "ice cream" varies from one country to another. In some countries, such as the United States, "ice cream" applies only to a specific variety, and most governments regulate the commercial use of the various terms according to the relative quantities of the main ingredients, notably the amount of cream. Products that do not meet the criteria to be called ice cream are sometimes labelled "frozen dairy dessert" instead. In other countries, such as Italy and Argentina, one word is used fo\r all variants. 
Analogues made from dairy alternatives, such as goat\'s or sheep\'s milk, or milk substitutes (e.g., soy, cashew, coconut, almond milk or tofu), are available for those who are lactose intolerant, allergic to dairy protein or vegan.', ) asyncio.run(main()) """ _response = await self._raw_client.summarize( text=text, length=length, format=format, model=model, extractiveness=extractiveness, temperature=temperature, additional_command=additional_command, request_options=request_options, ) return _response.data async def tokenize( self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None ) -> TokenizeResponse: """ This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page. Parameters ---------- text : str The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters. model : str The input will be tokenized by the tokenizer that is used by this model. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- TokenizeResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.tokenize( text="tokenize me! :D", model="command", ) asyncio.run(main()) """ _response = await self._raw_client.tokenize(text=text, model=model, request_options=request_options) return _response.data async def detokenize( self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None ) -> DetokenizeResponse: """ This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page. Parameters ---------- tokens : typing.Sequence[int] The list of tokens to be detokenized. 
model : str An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- DetokenizeResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.detokenize( tokens=[10002, 2261, 2012, 8, 2792, 43], model="command", ) asyncio.run(main()) """ _response = await self._raw_client.detokenize(tokens=tokens, model=model, request_options=request_options) return _response.data async def check_api_key(self, *, request_options: typing.Optional[RequestOptions] = None) -> CheckApiKeyResponse: """ Checks that the api key in the Authorization header is valid and active Parameters ---------- request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- CheckApiKeyResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.check_api_key() asyncio.run(main()) """ _response = await self._raw_client.check_api_key(request_options=request_options) return _response.data @property def v2(self): if self._v2 is None: from .v2.client import AsyncV2Client # noqa: E402 self._v2 = AsyncV2Client(client_wrapper=self._client_wrapper) return self._v2 @property def batches(self): if self._batches is None: from .batches.client import AsyncBatchesClient # noqa: E402 self._batches = AsyncBatchesClient(client_wrapper=self._client_wrapper) return self._batches @property def embed_jobs(self): if self._embed_jobs is None: from .embed_jobs.client import AsyncEmbedJobsClient # noqa: E402 self._embed_jobs = AsyncEmbedJobsClient(client_wrapper=self._client_wrapper) return self._embed_jobs @property def datasets(self): if self._datasets is None: from 
.datasets.client import AsyncDatasetsClient # noqa: E402 self._datasets = AsyncDatasetsClient(client_wrapper=self._client_wrapper) return self._datasets @property def connectors(self): if self._connectors is None: from .connectors.client import AsyncConnectorsClient # noqa: E402 self._connectors = AsyncConnectorsClient(client_wrapper=self._client_wrapper) return self._connectors @property def models(self): if self._models is None: from .models.client import AsyncModelsClient # noqa: E402 self._models = AsyncModelsClient(client_wrapper=self._client_wrapper) return self._models @property def finetuning(self): if self._finetuning is None: from .finetuning.client import AsyncFinetuningClient # noqa: E402 self._finetuning = AsyncFinetuningClient(client_wrapper=self._client_wrapper) return self._finetuning @property def audio(self): if self._audio is None: from .audio.client import AsyncAudioClient # noqa: E402 self._audio = AsyncAudioClient(client_wrapper=self._client_wrapper) return self._audio def _get_base_url(*, base_url: typing.Optional[str] = None, environment: ClientEnvironment) -> str: if base_url is not None: return base_url elif environment is not None: return environment.value else: raise Exception("Please pass in either base_url or environment to construct the client") ================================================ FILE: src/cohere/batches/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .types import ( Batch, BatchStatus, CancelBatchResponse, CreateBatchResponse, GetBatchResponse, ListBatchesResponse, ) _dynamic_imports: typing.Dict[str, str] = { "Batch": ".types", "BatchStatus": ".types", "CancelBatchResponse": ".types", "CreateBatchResponse": ".types", "GetBatchResponse": ".types", "ListBatchesResponse": ".types", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ "Batch", "BatchStatus", "CancelBatchResponse", "CreateBatchResponse", "GetBatchResponse", "ListBatchesResponse", ] ================================================ FILE: src/cohere/batches/client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.request_options import RequestOptions from .raw_client import AsyncRawBatchesClient, RawBatchesClient from .types.batch import Batch from .types.cancel_batch_response import CancelBatchResponse from .types.create_batch_response import CreateBatchResponse from .types.get_batch_response import GetBatchResponse from .types.list_batches_response import ListBatchesResponse # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) 
class BatchesClient:
    """Synchronous high-level client for the v2 Batches API.

    Thin wrapper over :class:`RawBatchesClient`: each method delegates to the
    raw client and returns the parsed ``.data`` instead of the HTTP wrapper.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawBatchesClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawBatchesClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawBatchesClient
        """
        return self._raw_client

    def list(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListBatchesResponse:
        """
        List the batches for the current user

        Parameters
        ----------
        page_size : typing.Optional[int]
            The maximum number of batches to return. The service may return fewer than
            this value. If unspecified, at most 50 batches will be returned. The maximum
            value is 1000; values above 1000 will be coerced to 1000.

        page_token : typing.Optional[str]
            A page token, received from a previous `ListBatches` call.
            Provide this to retrieve the subsequent page.

        order_by : typing.Optional[str]
            Batches can be ordered by creation time or last updated time. Use
            `created_at` for creation time or `updated_at` for last updated time.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListBatchesResponse
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.batches.list(
            page_size=1,
            page_token="page_token",
            order_by="order_by",
        )
        """
        _response = self._raw_client.list(
            page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options
        )
        return _response.data

    def create(self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None) -> CreateBatchResponse:
        """
        Creates and executes a batch from an uploaded dataset of requests

        Parameters
        ----------
        request : Batch

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateBatchResponse
            A successful response.

        Examples
        --------
        from cohere import Client
        from cohere.batches import Batch

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.batches.create(
            request=Batch(
                name="name",
                input_dataset_id="input_dataset_id",
                model="model",
            ),
        )
        """
        _response = self._raw_client.create(request=request, request_options=request_options)
        return _response.data

    def retrieve(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetBatchResponse:
        """
        Retrieves a batch

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetBatchResponse
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.batches.retrieve(
            id="id",
        )
        """
        _response = self._raw_client.retrieve(id, request_options=request_options)
        return _response.data

    def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> CancelBatchResponse:
        """
        Cancels an in-progress batch

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CancelBatchResponse
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.batches.cancel(
            id="id",
        )
        """
        _response = self._raw_client.cancel(id, request_options=request_options)
        return _response.data


class AsyncBatchesClient:
    """Asynchronous high-level client for the v2 Batches API.

    Thin wrapper over :class:`AsyncRawBatchesClient`: each coroutine delegates
    to the raw client and returns the parsed ``.data`` instead of the HTTP wrapper.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawBatchesClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawBatchesClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawBatchesClient
        """
        return self._raw_client

    async def list(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListBatchesResponse:
        """
        List the batches for the current user

        Parameters
        ----------
        page_size : typing.Optional[int]
            The maximum number of batches to return. The service may return fewer than
            this value. If unspecified, at most 50 batches will be returned. The maximum
            value is 1000; values above 1000 will be coerced to 1000.

        page_token : typing.Optional[str]
            A page token, received from a previous `ListBatches` call.
            Provide this to retrieve the subsequent page.

        order_by : typing.Optional[str]
            Batches can be ordered by creation time or last updated time. Use
            `created_at` for creation time or `updated_at` for last updated time.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListBatchesResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.batches.list(
                page_size=1,
                page_token="page_token",
                order_by="order_by",
            )


        asyncio.run(main())
        """
        _response = await self._raw_client.list(
            page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options
        )
        return _response.data

    async def create(
        self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None
    ) -> CreateBatchResponse:
        """
        Creates and executes a batch from an uploaded dataset of requests

        Parameters
        ----------
        request : Batch

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateBatchResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient
        from cohere.batches import Batch

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.batches.create(
                request=Batch(
                    name="name",
                    input_dataset_id="input_dataset_id",
                    model="model",
                ),
            )


        asyncio.run(main())
        """
        _response = await self._raw_client.create(request=request, request_options=request_options)
        return _response.data

    async def retrieve(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetBatchResponse:
        """
        Retrieves a batch

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetBatchResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.batches.retrieve(
                id="id",
            )


        asyncio.run(main())
        """
        _response = await self._raw_client.retrieve(id, request_options=request_options)
        return _response.data

    async def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> CancelBatchResponse:
        """
        Cancels an in-progress batch

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CancelBatchResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.batches.cancel(
                id="id",
            )


        asyncio.run(main())
        """
        _response = await self._raw_client.cancel(id, request_options=request_options)
        return _response.data


================================================
FILE: src/cohere/batches/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing from json.decoder import JSONDecodeError from ..core.api_error import ApiError from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.http_response import AsyncHttpResponse, HttpResponse from ..core.jsonable_encoder import jsonable_encoder from ..core.parse_error import ParsingError from ..core.request_options import RequestOptions from ..core.serialization import convert_and_respect_annotation_metadata from ..core.unchecked_base_model import construct_type from ..errors.bad_request_error import BadRequestError from ..errors.forbidden_error import ForbiddenError from ..errors.internal_server_error import InternalServerError from ..errors.not_found_error import NotFoundError from ..errors.service_unavailable_error import ServiceUnavailableError from ..errors.unauthorized_error import UnauthorizedError from .types.batch import Batch from .types.cancel_batch_response import CancelBatchResponse from .types.create_batch_response import CreateBatchResponse from .types.get_batch_response import GetBatchResponse from .types.list_batches_response import ListBatchesResponse from pydantic import ValidationError # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) class RawBatchesClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._client_wrapper = client_wrapper def list( self, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[ListBatchesResponse]: """ List the batches for the current user Parameters ---------- page_size : typing.Optional[int] The maximum number of batches to return. The service may return fewer than this value. If unspecified, at most 50 batches will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000. 
page_token : typing.Optional[str] A page token, received from a previous `ListBatches` call. Provide this to retrieve the subsequent page. order_by : typing.Optional[str] Batches can be ordered by creation time or last updated time. Use `created_at` for creation time or `updated_at` for last updated time. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[ListBatchesResponse] A successful response. """ _response = self._client_wrapper.httpx_client.request( "v2/batches", method="GET", params={ "page_size": page_size, "page_token": page_token, "order_by": order_by, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListBatchesResponse, construct_type( type_=ListBatchesResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def create( self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[CreateBatchResponse]: """ Creates and executes a batch from an uploaded dataset of requests Parameters ---------- request : Batch request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[CreateBatchResponse] A successful response. """ _response = self._client_wrapper.httpx_client.request( "v2/batches", method="POST", json=convert_and_respect_annotation_metadata(object_=request, annotation=Batch, direction="write"), headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CreateBatchResponse, construct_type( type_=CreateBatchResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, 
construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def retrieve( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[GetBatchResponse]: """ Retrieves a batch Parameters ---------- id : str The batch ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[GetBatchResponse] A successful response. 
""" _response = self._client_wrapper.httpx_client.request( f"v2/batches/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( GetBatchResponse, construct_type( type_=GetBatchResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), 
body=_response_json) def cancel( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[CancelBatchResponse]: """ Cancels an in-progress batch Parameters ---------- id : str The batch ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[CancelBatchResponse] A successful response. """ _response = self._client_wrapper.httpx_client.request( f"v2/batches/{jsonable_encoder(id)}:cancel", method="POST", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CancelBatchResponse, construct_type( type_=CancelBatchResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json 
= _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawBatchesClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper async def list( self, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[ListBatchesResponse]: """ List the batches for the current user Parameters ---------- page_size : typing.Optional[int] The maximum number of batches to return. The service may return fewer than this value. If unspecified, at most 50 batches will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000. page_token : typing.Optional[str] A page token, received from a previous `ListBatches` call. Provide this to retrieve the subsequent page. order_by : typing.Optional[str] Batches can be ordered by creation time or last updated time. Use `created_at` for creation time or `updated_at` for last updated time. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[ListBatchesResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( "v2/batches", method="GET", params={ "page_size": page_size, "page_token": page_token, "order_by": order_by, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListBatchesResponse, construct_type( type_=ListBatchesResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise 
ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def create( self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[CreateBatchResponse]: """ Creates and executes a batch from an uploaded dataset of requests Parameters ---------- request : Batch request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[CreateBatchResponse] A successful response. """ _response = await self._client_wrapper.httpx_client.request( "v2/batches", method="POST", json=convert_and_respect_annotation_metadata(object_=request, annotation=Batch, direction="write"), headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CreateBatchResponse, construct_type( type_=CreateBatchResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # 
type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def retrieve( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[GetBatchResponse]: """ Retrieves a batch Parameters ---------- id : str The batch ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[GetBatchResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( f"v2/batches/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( GetBatchResponse, construct_type( type_=GetBatchResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), 
body=_response_json) async def cancel( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[CancelBatchResponse]: """ Cancels an in-progress batch Parameters ---------- id : str The batch ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[CancelBatchResponse] A successful response. """ _response = await self._client_wrapper.httpx_client.request( f"v2/batches/{jsonable_encoder(id)}:cancel", method="POST", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CancelBatchResponse, construct_type( type_=CancelBatchResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) ================================================ FILE: src/cohere/batches/types/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. # isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .batch import Batch from .batch_status import BatchStatus from .cancel_batch_response import CancelBatchResponse from .create_batch_response import CreateBatchResponse from .get_batch_response import GetBatchResponse from .list_batches_response import ListBatchesResponse _dynamic_imports: typing.Dict[str, str] = { "Batch": ".batch", "BatchStatus": ".batch_status", "CancelBatchResponse": ".cancel_batch_response", "CreateBatchResponse": ".create_batch_response", "GetBatchResponse": ".get_batch_response", "ListBatchesResponse": ".list_batches_response", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ "Batch", 
# ======================================================================
# src/cohere/batches/types/batch.py
# ======================================================================
# This file was auto-generated by Fern from our API Definition.

import datetime as dt
import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch_status import BatchStatus


class Batch(UncheckedBaseModel):
    """
    This resource represents a batch job.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. Batch ID.
    """

    name: str = pydantic.Field()
    """
    Batch name (e.g. `foobar`).
    """

    creator_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. User ID of the creator.
    """

    org_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. Organization ID.
    """

    status: typing.Optional[BatchStatus] = pydantic.Field(default=None)
    """
    read-only. Current stage in the life-cycle of the batch.
    """

    created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Creation timestamp.
    """

    updated_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Latest update timestamp.
    """

    input_dataset_id: str = pydantic.Field()
    """
    ID of the dataset the batch reads inputs from.
    """

    output_dataset_id: typing.Optional[str] = None

    # NOTE(review): token counts are declared as strings, not ints —
    # presumably int64 values serialized as strings by the backend; confirm
    # against the API schema before changing.
    input_tokens: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. The total number of input tokens in the batch.
    """

    output_tokens: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. The total number of output tokens in the batch.
    """

    model: str = pydantic.Field()
    """
    The name of the model the batch uses.
    """

    num_records: typing.Optional[int] = pydantic.Field(default=None)
    """
    read-only. The total number of records in the batch.
    """

    num_successful_records: typing.Optional[int] = pydantic.Field(default=None)
    """
    read-only. The current number of successful records in the batch.
    """

    num_failed_records: typing.Optional[int] = pydantic.Field(default=None)
    """
    read-only. The current number of failed records in the batch.
    """

    status_reason: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. More details about the reason for the status of a batch job.
    """

    # Unknown fields from the API are kept rather than rejected, so newer
    # server responses do not break older SDK versions.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# ======================================================================
# src/cohere/batches/types/batch_status.py
# ======================================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Union with typing.Any keeps forward compatibility: unknown statuses from
# newer API versions still validate instead of failing.
BatchStatus = typing.Union[
    typing.Literal[
        "BATCH_STATUS_UNSPECIFIED",
        "BATCH_STATUS_QUEUED",
        "BATCH_STATUS_IN_PROGRESS",
        "BATCH_STATUS_CANCELING",
        "BATCH_STATUS_COMPLETED",
        "BATCH_STATUS_FAILED",
        "BATCH_STATUS_CANCELED",
    ],
    typing.Any,
]


# ======================================================================
# src/cohere/batches/types/cancel_batch_response.py
# ======================================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Untyped payload: the cancel endpoint returns an open-ended JSON object.
CancelBatchResponse = typing.Dict[str, typing.Any]
"""
Response to a request to cancel a batch.
"""


# ======================================================================
# src/cohere/batches/types/create_batch_response.py
# ======================================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch import Batch


class CreateBatchResponse(UncheckedBaseModel):
    """
    Response to request to create a batch.
    """

    batch: Batch = pydantic.Field()
    """
    Information about the batch.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# ======================================================================
# src/cohere/batches/types/get_batch_response.py
# ======================================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch import Batch


class GetBatchResponse(UncheckedBaseModel):
    """
    Response to a request to get a batch.
    """

    batch: Batch = pydantic.Field()
    """
    Information about the batch.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# ======================================================================
# src/cohere/batches/types/list_batches_response.py
# ======================================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch import Batch


class ListBatchesResponse(UncheckedBaseModel):
    """
    Response to a request to list batches.
    """

    batches: typing.Optional[typing.List[Batch]] = pydantic.Field(default=None)
    """
    The batches that belong to the authenticated user.
    """

    next_page_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    A token, which can be sent as `page_token` to retrieve the next page.
    If this field is omitted, there are no subsequent pages.
    """
class BedrockClient(AwsClient):
    """Cohere client targeting the Amazon Bedrock service (v1 API surface)."""

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        # Collect the AWS credential settings once and forward them to the
        # shared AWS client with the service name pinned to "bedrock".
        aws_settings = dict(
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        AwsClient.__init__(self, service="bedrock", timeout=timeout, **aws_settings)

    def rerank(self, *, query, documents, model = ..., top_n = ..., rank_fields = ..., return_documents = ..., max_chunks_per_doc = ..., request_options = None):
        # Rerank on Bedrock is only supported through the v2 client; fail
        # loudly instead of issuing an unsupported request.
        raise NotImplementedError("Please use cohere.BedrockClientV2 instead: Rerank API on Bedrock is not supported with cohere.BedrockClient for this model.")


class BedrockClientV2(AwsClientV2):
    """Cohere client targeting the Amazon Bedrock service (v2 API surface)."""

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        # Same construction as BedrockClient, delegating to the v2 AWS base.
        aws_settings = dict(
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        AwsClientV2.__init__(self, service="bedrock", timeout=timeout, **aws_settings)
def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None:
    """Monkey-patch ``obj.<method_name>`` so ``check_fn`` validates every call first.

    The replacement preserves the original method's name and docstring, and
    stays a coroutine when the wrapped method is one.
    """
    method = getattr(obj, method_name)

    def _sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return method(*args, **kwargs)

    async def _async_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        # ``return await`` keeps the coroutine's result (and exceptions)
        # flowing back to the caller exactly as before.
        return await method(*args, **kwargs)

    wrapper = _async_wrapper if asyncio.iscoroutinefunction(method) else _sync_wrapper
    wrapper.__name__ = method.__name__
    wrapper.__doc__ = method.__doc__
    setattr(obj, method_name, wrapper)


def throw_if_stream_is_true(*args, **kwargs) -> None:
    """Reject the pre-5.0.0 ``chat(stream=True, ...)`` calling convention."""
    if kwargs.get("stream") is not True:
        return
    raise ValueError(
        "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
    )


def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
    """Return a stub that directs callers from ``fn_name`` to its new home."""

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been moved to {new_fn_name}(...). "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn
def deprecated_function(fn_name: str) -> typing.Any:
    """Return a stub that raises because ``fn_name`` was removed in cohere==5.0.0."""

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been deprecated. "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn


def experimental_kwarg_decorator(func, deprecated_kwarg):
    """Wrap ``func`` so a warning is logged whenever the experimental kwarg is used.

    ``deprecated_kwarg`` may be dot-separated (e.g. ``"response_format.schema"``)
    to address a key nested inside a dict-valued keyword argument.
    """

    def _has_kwarg(path: str, kwargs: typing.Dict[str, typing.Any]) -> bool:
        # Walk nested dicts for dotted paths; when the head key is absent we
        # still check the full dotted name as a literal kwarg, matching the
        # original lookup behavior.
        if "." in path:
            head, tail = path.split(".", 1)
            if head in kwargs:
                return _has_kwarg(tail, kwargs[head])
        return path in kwargs

    def _wrapped(*args, **kwargs):
        if _has_kwarg(deprecated_kwarg, kwargs):
            logger.warning(
                f"The `{deprecated_kwarg}` parameter is an experimental feature and may change in future releases.\n"
                "To suppress this warning, set `log_warning_experimental_features=False` when initializing the client."
            )
        return func(*args, **kwargs)

    async def _async_wrapped(*args, **kwargs):
        if _has_kwarg(deprecated_kwarg, kwargs):
            logger.warning(
                f"The `{deprecated_kwarg}` parameter is an experimental feature and may change in future releases.\n"
                "To suppress this warning, set `log_warning_experimental_features=False` when initializing the client."
            )
        return await func(*args, **kwargs)

    wrapper = _async_wrapped if asyncio.iscoroutinefunction(func) else _wrapped
    wrapper.__name__ = func.__name__
    wrapper.__doc__ = func.__doc__
    return wrapper
def fix_base_url(base_url: typing.Optional[str]) -> typing.Optional[str]:
    """Strip the legacy ``/v1`` path segment from Cohere-hosted base URLs.

    URLs pointing at other hosts are returned unchanged; ``None`` stays ``None``.
    """
    if base_url is None:
        return None
    is_cohere_host = "cohere.com" in base_url or "cohere.ai" in base_url
    if not is_cohere_host:
        return base_url
    return base_url.replace("/v1", "")
    def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        """
        Embed ``texts`` (or ``images``), transparently splitting large text
        inputs into batches of ``embed_batch_size``, issuing the batch calls
        in parallel on the client's thread pool, and merging the partial
        responses back into a single ``EmbedResponse``.

        Set ``batching=False`` to send everything in one request.
        """
        # skip batching for images for now
        if batching is False or images is not OMIT:
            return BaseCohere.embed(
                self,
                texts=texts,
                images=images,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )
        # Normalize the sentinel/None cases to an empty sequence before slicing.
        textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

        # Executor.map preserves input order, so merged responses line up
        # with the original text order.
        responses = [
            response
            for response in self._executor.map(
                lambda text_batch: BaseCohere.embed(
                    self,
                    texts=text_batch,
                    model=model,
                    input_type=input_type,
                    embedding_types=embedding_types,
                    truncate=truncate,
                    request_options=request_options,
                ),
                texts_batches,
            )
        ]

        return merge_embed_responses(responses)

    def embed_stream(
        self,
        *,
        texts: typing.Sequence[str],
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        batch_size: int = embed_stream_batch_size,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.Iterator[typing.Any]:
        """
        Memory-efficient embed that yields embeddings one batch at a time.

        Processes texts in batches and yields individual StreamedEmbedding objects
        as they come back, so you can write to a vector store incrementally
        without holding all embeddings in memory.

        Args:
            texts: Texts to embed.
            model: Embedding model ID.
            input_type: Input type (search_document, search_query, etc.).
            embedding_types: Types of embeddings to return (float, int8, etc.).
            truncate: How to handle inputs longer than the max token length.
            batch_size: Texts per API call. Defaults to 96 (API max).
            request_options: Request-specific configuration.

        Yields:
            StreamedEmbedding with index, embedding, embedding_type, and text.
        """
        # Imported lazily to avoid a module-level import cycle with the
        # manually maintained helpers.
        from .manually_maintained.streaming_embed import extract_embeddings_from_response

        if not texts:
            return
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        texts_list = list(texts)
        for batch_start in range(0, len(texts_list), batch_size):
            batch_texts = texts_list[batch_start : batch_start + batch_size]
            response = BaseCohere.embed(
                self,
                texts=batch_texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )
            # NOTE(review): response.dict() is the pydantic-v1 spelling;
            # presumably the response model keeps a .dict() shim under v2 —
            # confirm before upgrading pydantic usage.
            response_data = response.dict() if hasattr(response, "dict") else response.__dict__
            yield from extract_embeddings_from_response(response_data, batch_texts, batch_start)
""" check_api_key: Never = deprecated_function("check_api_key") loglikelihood: Never = deprecated_function("loglikelihood") batch_generate: Never = deprecated_function("batch_generate") codebook: Never = deprecated_function("codebook") batch_tokenize: Never = deprecated_function("batch_tokenize") batch_detokenize: Never = deprecated_function("batch_detokenize") detect_language: Never = deprecated_function("detect_language") generate_feedback: Never = deprecated_function("generate_feedback") generate_preference_feedback: Never = deprecated_function("generate_preference_feedback") create_dataset: Never = moved_function("create_dataset", ".datasets.create") get_dataset: Never = moved_function("get_dataset", ".datasets.get") list_datasets: Never = moved_function("list_datasets", ".datasets.list") delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete") get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage") wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait") _check_response: Never = deprecated_function("_check_response") _request: Never = deprecated_function("_request") create_cluster_job: Never = deprecated_function("create_cluster_job") get_cluster_job: Never = deprecated_function("get_cluster_job") list_cluster_jobs: Never = deprecated_function("list_cluster_jobs") wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job") create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create") list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list") get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get") cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel") wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait") create_custom_model: Never = deprecated_function("create_custom_model") wait_for_custom_model: Never = deprecated_function("wait_for_custom_model") _upload_dataset: Never = 
deprecated_function("_upload_dataset") _create_signed_url: Never = deprecated_function("_create_signed_url") get_custom_model: Never = deprecated_function("get_custom_model") get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name") get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics") list_custom_models: Never = deprecated_function("list_custom_models") create_connector: Never = moved_function("create_connector", ".connectors.create") update_connector: Never = moved_function("update_connector", ".connectors.update") get_connector: Never = moved_function("get_connector", ".connectors.get") list_connectors: Never = moved_function("list_connectors", ".connectors.list") delete_connector: Never = moved_function("delete_connector", ".connectors.delete") oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize") def tokenize( self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None, offline: bool = True, ) -> TokenizeResponse: # `offline` parameter controls whether to use an offline tokenizer. If set to True, the tokenizer config will be downloaded (and cached), # and the request will be processed using the offline tokenizer. If set to False, the request will be processed using the API. The default value is True. opts: RequestOptions = request_options or {} # type: ignore if offline: try: tokens = local_tokenizers.local_tokenize(self, text=text, model=model) return TokenizeResponse(tokens=tokens, token_strings=[]) except Exception: # Fallback to calling the API. 
opts["additional_headers"] = opts.get("additional_headers", {}) opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed" return super().tokenize(text=text, model=model, request_options=opts) def detokenize( self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None, offline: typing.Optional[bool] = True, ) -> DetokenizeResponse: # `offline` parameter controls whether to use an offline tokenizer. If set to True, the tokenizer config will be downloaded (and cached), # and the request will be processed using the offline tokenizer. If set to False, the request will be processed using the API. The default value is True. opts: RequestOptions = request_options or {} # type: ignore if offline: try: text = local_tokenizers.local_detokenize(self, model=model, tokens=tokens) return DetokenizeResponse(text=text) except Exception: # Fallback to calling the API. opts["additional_headers"] = opts.get("additional_headers", {}) opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed" return super().detokenize(tokens=tokens, model=model, request_options=opts) def fetch_tokenizer(self, *, model: str) -> Tokenizer: """ Returns a Hugging Face tokenizer from a given model name. 
class AsyncClient(AsyncBaseCohere, CacheMixin):
    # Thread pool kept for API parity with Client; the async embed path uses
    # asyncio.gather rather than this executor.
    _executor: ThreadPoolExecutor

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        # NOTE(review): the default ThreadPoolExecutor(64) is evaluated once at
        # import time, so every client constructed without an explicit executor
        # shares the same pool — confirm this sharing is intended.
        if api_key is None:
            api_key = _get_api_key_from_environment()
        base_url = fix_base_url(base_url)
        self._executor = thread_pool_executor
        AsyncBaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )
        # chat(stream=True) was removed in 5.0.0; fail fast with a clear error.
        validate_args(self, "chat", throw_if_stream_is_true)
        if log_warning_experimental_features:
            self.chat = experimental_kwarg_decorator(self.chat, "response_format.schema")  # type: ignore
            self.chat_stream = experimental_kwarg_decorator(self.chat_stream, "response_format.schema")  # type: ignore

    utils = AsyncSdkUtils()

    # support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._client_wrapper.httpx_client.httpx_client.aclose()

    wait = async_wait

    async def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        """
        Embed ``texts`` (or ``images``), splitting large text inputs into
        batches of ``embed_batch_size``, awaiting all batch calls concurrently
        with ``asyncio.gather``, and merging the partial responses into a
        single ``EmbedResponse``. Set ``batching=False`` to send one request.
        """
        # skip batching for images for now
        if batching is False or images is not OMIT:
            return await AsyncBaseCohere.embed(
                self,
                texts=texts,
                images=images,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )
        # Normalize the sentinel/None cases to an empty sequence before slicing.
        textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

        # gather preserves input order, so merged responses line up with the
        # original text order.
        responses = typing.cast(
            typing.List[EmbedResponse],
            await asyncio.gather(
                *[
                    AsyncBaseCohere.embed(
                        self,
                        texts=text_batch,
                        model=model,
                        input_type=input_type,
                        embedding_types=embedding_types,
                        truncate=truncate,
                        request_options=request_options,
                    )
                    for text_batch in texts_batches
                ]
            ),
        )

        return merge_embed_responses(responses)

    """
    The following methods have been moved or deprecated in cohere==5.0.0. Please
    update your usage. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    """
""" check_api_key: Never = deprecated_function("check_api_key") loglikelihood: Never = deprecated_function("loglikelihood") batch_generate: Never = deprecated_function("batch_generate") codebook: Never = deprecated_function("codebook") batch_tokenize: Never = deprecated_function("batch_tokenize") batch_detokenize: Never = deprecated_function("batch_detokenize") detect_language: Never = deprecated_function("detect_language") generate_feedback: Never = deprecated_function("generate_feedback") generate_preference_feedback: Never = deprecated_function("generate_preference_feedback") create_dataset: Never = moved_function("create_dataset", ".datasets.create") get_dataset: Never = moved_function("get_dataset", ".datasets.get") list_datasets: Never = moved_function("list_datasets", ".datasets.list") delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete") get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage") wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait") _check_response: Never = deprecated_function("_check_response") _request: Never = deprecated_function("_request") create_cluster_job: Never = deprecated_function("create_cluster_job") get_cluster_job: Never = deprecated_function("get_cluster_job") list_cluster_jobs: Never = deprecated_function("list_cluster_jobs") wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job") create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create") list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list") get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get") cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel") wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait") create_custom_model: Never = deprecated_function("create_custom_model") wait_for_custom_model: Never = deprecated_function("wait_for_custom_model") _upload_dataset: Never = 
deprecated_function("_upload_dataset") _create_signed_url: Never = deprecated_function("_create_signed_url") get_custom_model: Never = deprecated_function("get_custom_model") get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name") get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics") list_custom_models: Never = deprecated_function("list_custom_models") create_connector: Never = moved_function("create_connector", ".connectors.create") update_connector: Never = moved_function("update_connector", ".connectors.update") get_connector: Never = moved_function("get_connector", ".connectors.get") list_connectors: Never = moved_function("list_connectors", ".connectors.list") delete_connector: Never = moved_function("delete_connector", ".connectors.delete") oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize") async def tokenize( self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None, offline: typing.Optional[bool] = True, ) -> TokenizeResponse: # `offline` parameter controls whether to use an offline tokenizer. If set to True, the tokenizer config will be downloaded (and cached), # and the request will be processed using the offline tokenizer. If set to False, the request will be processed using the API. The default value is True. 
opts: RequestOptions = request_options or {} # type: ignore if offline: try: tokens = await local_tokenizers.async_local_tokenize(self, model=model, text=text) return TokenizeResponse(tokens=tokens, token_strings=[]) except Exception: opts["additional_headers"] = opts.get("additional_headers", {}) opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed" return await super().tokenize(text=text, model=model, request_options=opts) async def detokenize( self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None, offline: typing.Optional[bool] = True, ) -> DetokenizeResponse: # `offline` parameter controls whether to use an offline tokenizer. If set to True, the tokenizer config will be downloaded (and cached), # and the request will be processed using the offline tokenizer. If set to False, the request will be processed using the API. The default value is True. opts: RequestOptions = request_options or {} # type: ignore if offline: try: text = await local_tokenizers.async_local_detokenize(self, model=model, tokens=tokens) return DetokenizeResponse(text=text) except Exception: opts["additional_headers"] = opts.get("additional_headers", {}) opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed" return await super().detokenize(tokens=tokens, model=model, request_options=opts) async def fetch_tokenizer(self, *, model: str) -> Tokenizer: """ Returns a Hugging Face tokenizer from a given model name. """ return await local_tokenizers.async_get_hf_tokenizer(self, model) def _get_api_key_from_environment() -> typing.Optional[str]: """ Retrieves the Cohere API key from specific environment variables. CO_API_KEY is preferred (and documented) COHERE_API_KEY is accepted (but not documented). 
""" return os.getenv("CO_API_KEY", os.getenv("COHERE_API_KEY")) ================================================ FILE: src/cohere/client_v2.py ================================================ import os import typing from concurrent.futures import ThreadPoolExecutor import httpx from .client import AsyncClient, Client from .environment import ClientEnvironment from .v2.client import AsyncRawV2Client, AsyncV2Client, RawV2Client, V2Client class _CombinedRawClient: """Proxy that combines v1 and v2 raw clients. V2Client and Client both assign to self._raw_client in __init__, causing a collision when combined in ClientV2/AsyncClientV2. This proxy delegates to v2 first, falling back to v1 for legacy methods like generate_stream. """ def __init__(self, v1_raw_client: typing.Any, v2_raw_client: typing.Any): self._v1 = v1_raw_client self._v2 = v2_raw_client def __getattr__(self, name: str) -> typing.Any: try: return getattr(self._v2, name) except AttributeError: return getattr(self._v1, name) class ClientV2(V2Client, Client): # type: ignore def __init__( self, api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None, *, base_url: typing.Optional[str] = os.getenv("CO_API_URL"), environment: ClientEnvironment = ClientEnvironment.PRODUCTION, client_name: typing.Optional[str] = None, timeout: typing.Optional[float] = None, httpx_client: typing.Optional[httpx.Client] = None, thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64), log_warning_experimental_features: bool = True, ): Client.__init__( self, api_key=api_key, base_url=base_url, environment=environment, client_name=client_name, timeout=timeout, httpx_client=httpx_client, thread_pool_executor=thread_pool_executor, log_warning_experimental_features=log_warning_experimental_features, ) v1_raw = self._raw_client V2Client.__init__( self, client_wrapper=self._client_wrapper ) self._raw_client = typing.cast(RawV2Client, _CombinedRawClient(v1_raw, self._raw_client)) class 
class AsyncClientV2(AsyncV2Client, AsyncClient):  # type: ignore
    """Async v2 client layered over the v1 AsyncClient for legacy endpoints."""

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        # Initialize the v1 client first so the client wrapper exists...
        AsyncClient.__init__(
            self,
            api_key=api_key,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            timeout=timeout,
            httpx_client=httpx_client,
            thread_pool_executor=thread_pool_executor,
            log_warning_experimental_features=log_warning_experimental_features,
        )
        # ...then capture its raw client before V2's __init__ overwrites the
        # _raw_client attribute, and combine the two behind a proxy that
        # prefers v2 and falls back to v1.
        v1_raw = self._raw_client
        AsyncV2Client.__init__(
            self,
            client_wrapper=self._client_wrapper
        )
        self._raw_client = typing.cast(AsyncRawV2Client, _CombinedRawClient(v1_raw, self._raw_client))


# ======================================================================
# src/cohere/config.py
# ======================================================================
# Shared batching limits for the embed helpers.
embed_batch_size = 96
embed_stream_batch_size = 96  # Max texts per API request (API limit)
import typing from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.request_options import RequestOptions from ..types.create_connector_o_auth import CreateConnectorOAuth from ..types.create_connector_response import CreateConnectorResponse from ..types.create_connector_service_auth import CreateConnectorServiceAuth from ..types.delete_connector_response import DeleteConnectorResponse from ..types.get_connector_response import GetConnectorResponse from ..types.list_connectors_response import ListConnectorsResponse from ..types.o_auth_authorize_response import OAuthAuthorizeResponse from ..types.update_connector_response import UpdateConnectorResponse from .raw_client import AsyncRawConnectorsClient, RawConnectorsClient # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) class ConnectorsClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._raw_client = RawConnectorsClient(client_wrapper=client_wrapper) @property def with_raw_response(self) -> RawConnectorsClient: """ Retrieves a raw implementation of this client that returns raw responses. Returns ------- RawConnectorsClient """ return self._raw_client def list( self, *, limit: typing.Optional[float] = None, offset: typing.Optional[float] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListConnectorsResponse: """ Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. Parameters ---------- limit : typing.Optional[float] Maximum number of connectors to return [0, 100]. offset : typing.Optional[float] Number of connectors to skip before returning results [0, inf]. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- ListConnectorsResponse OK Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.connectors.list( limit=1.1, offset=1.1, ) """ _response = self._raw_client.list(limit=limit, offset=offset, request_options=request_options) return _response.data def create( self, *, name: str, url: str, description: typing.Optional[str] = OMIT, excludes: typing.Optional[typing.Sequence[str]] = OMIT, oauth: typing.Optional[CreateConnectorOAuth] = OMIT, active: typing.Optional[bool] = OMIT, continue_on_failure: typing.Optional[bool] = OMIT, service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> CreateConnectorResponse: """ Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information. Parameters ---------- name : str A human-readable name for the connector. url : str The URL of the connector that will be used to search for documents. description : typing.Optional[str] A description of the connector. excludes : typing.Optional[typing.Sequence[str]] A list of fields to exclude from the prompt (fields remain in the document). oauth : typing.Optional[CreateConnectorOAuth] The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified. active : typing.Optional[bool] Whether the connector is active or not. continue_on_failure : typing.Optional[bool] Whether a chat request should continue or not if the request to this connector fails. service_auth : typing.Optional[CreateConnectorServiceAuth] The service to service authentication configuration for the connector. Cannot be specified if oauth is specified. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- CreateConnectorResponse OK Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.connectors.create( name="name", url="url", ) """ _response = self._raw_client.create( name=name, url=url, description=description, excludes=excludes, oauth=oauth, active=active, continue_on_failure=continue_on_failure, service_auth=service_auth, request_options=request_options, ) return _response.data def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetConnectorResponse: """ Retrieve a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to retrieve. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- GetConnectorResponse OK Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.connectors.get( id="id", ) """ _response = self._raw_client.get(id, request_options=request_options) return _response.data def delete(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> DeleteConnectorResponse: """ Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to delete. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- DeleteConnectorResponse OK Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.connectors.delete( id="id", ) """ _response = self._raw_client.delete(id, request_options=request_options) return _response.data def update( self, id: str, *, name: typing.Optional[str] = OMIT, url: typing.Optional[str] = OMIT, excludes: typing.Optional[typing.Sequence[str]] = OMIT, oauth: typing.Optional[CreateConnectorOAuth] = OMIT, active: typing.Optional[bool] = OMIT, continue_on_failure: typing.Optional[bool] = OMIT, service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> UpdateConnectorResponse: """ Update a connector by ID. Omitted fields will not be updated. See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. Parameters ---------- id : str The ID of the connector to update. name : typing.Optional[str] A human-readable name for the connector. url : typing.Optional[str] The URL of the connector that will be used to search for documents. excludes : typing.Optional[typing.Sequence[str]] A list of fields to exclude from the prompt (fields remain in the document). oauth : typing.Optional[CreateConnectorOAuth] The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified. active : typing.Optional[bool] continue_on_failure : typing.Optional[bool] service_auth : typing.Optional[CreateConnectorServiceAuth] The service to service authentication configuration for the connector. Cannot be specified if oauth is specified. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- UpdateConnectorResponse OK Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.connectors.update( id="id", ) """ _response = self._raw_client.update( id, name=name, url=url, excludes=excludes, oauth=oauth, active=active, continue_on_failure=continue_on_failure, service_auth=service_auth, request_options=request_options, ) return _response.data def o_auth_authorize( self, id: str, *, after_token_redirect: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> OAuthAuthorizeResponse: """ Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information. Parameters ---------- id : str The ID of the connector to authorize. after_token_redirect : typing.Optional[str] The URL to redirect to after the connector has been authorized. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- OAuthAuthorizeResponse OK Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.connectors.o_auth_authorize( id="id", after_token_redirect="after_token_redirect", ) """ _response = self._raw_client.o_auth_authorize( id, after_token_redirect=after_token_redirect, request_options=request_options ) return _response.data class AsyncConnectorsClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._raw_client = AsyncRawConnectorsClient(client_wrapper=client_wrapper) @property def with_raw_response(self) -> AsyncRawConnectorsClient: """ Retrieves a raw implementation of this client that returns raw responses. 
Returns ------- AsyncRawConnectorsClient """ return self._raw_client async def list( self, *, limit: typing.Optional[float] = None, offset: typing.Optional[float] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListConnectorsResponse: """ Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. Parameters ---------- limit : typing.Optional[float] Maximum number of connectors to return [0, 100]. offset : typing.Optional[float] Number of connectors to skip before returning results [0, inf]. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListConnectorsResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.connectors.list( limit=1.1, offset=1.1, ) asyncio.run(main()) """ _response = await self._raw_client.list(limit=limit, offset=offset, request_options=request_options) return _response.data async def create( self, *, name: str, url: str, description: typing.Optional[str] = OMIT, excludes: typing.Optional[typing.Sequence[str]] = OMIT, oauth: typing.Optional[CreateConnectorOAuth] = OMIT, active: typing.Optional[bool] = OMIT, continue_on_failure: typing.Optional[bool] = OMIT, service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> CreateConnectorResponse: """ Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information. Parameters ---------- name : str A human-readable name for the connector. url : str The URL of the connector that will be used to search for documents. 
description : typing.Optional[str] A description of the connector. excludes : typing.Optional[typing.Sequence[str]] A list of fields to exclude from the prompt (fields remain in the document). oauth : typing.Optional[CreateConnectorOAuth] The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified. active : typing.Optional[bool] Whether the connector is active or not. continue_on_failure : typing.Optional[bool] Whether a chat request should continue or not if the request to this connector fails. service_auth : typing.Optional[CreateConnectorServiceAuth] The service to service authentication configuration for the connector. Cannot be specified if oauth is specified. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- CreateConnectorResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.connectors.create( name="name", url="url", ) asyncio.run(main()) """ _response = await self._raw_client.create( name=name, url=url, description=description, excludes=excludes, oauth=oauth, active=active, continue_on_failure=continue_on_failure, service_auth=service_auth, request_options=request_options, ) return _response.data async def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetConnectorResponse: """ Retrieve a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to retrieve. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- GetConnectorResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.connectors.get( id="id", ) asyncio.run(main()) """ _response = await self._raw_client.get(id, request_options=request_options) return _response.data async def delete( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> DeleteConnectorResponse: """ Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to delete. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- DeleteConnectorResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.connectors.delete( id="id", ) asyncio.run(main()) """ _response = await self._raw_client.delete(id, request_options=request_options) return _response.data async def update( self, id: str, *, name: typing.Optional[str] = OMIT, url: typing.Optional[str] = OMIT, excludes: typing.Optional[typing.Sequence[str]] = OMIT, oauth: typing.Optional[CreateConnectorOAuth] = OMIT, active: typing.Optional[bool] = OMIT, continue_on_failure: typing.Optional[bool] = OMIT, service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> UpdateConnectorResponse: """ Update a connector by ID. Omitted fields will not be updated. See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. Parameters ---------- id : str The ID of the connector to update. name : typing.Optional[str] A human-readable name for the connector. url : typing.Optional[str] The URL of the connector that will be used to search for documents. 
excludes : typing.Optional[typing.Sequence[str]] A list of fields to exclude from the prompt (fields remain in the document). oauth : typing.Optional[CreateConnectorOAuth] The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified. active : typing.Optional[bool] continue_on_failure : typing.Optional[bool] service_auth : typing.Optional[CreateConnectorServiceAuth] The service to service authentication configuration for the connector. Cannot be specified if oauth is specified. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- UpdateConnectorResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.connectors.update( id="id", ) asyncio.run(main()) """ _response = await self._raw_client.update( id, name=name, url=url, excludes=excludes, oauth=oauth, active=active, continue_on_failure=continue_on_failure, service_auth=service_auth, request_options=request_options, ) return _response.data async def o_auth_authorize( self, id: str, *, after_token_redirect: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> OAuthAuthorizeResponse: """ Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information. Parameters ---------- id : str The ID of the connector to authorize. after_token_redirect : typing.Optional[str] The URL to redirect to after the connector has been authorized. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- OAuthAuthorizeResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.connectors.o_auth_authorize( id="id", after_token_redirect="after_token_redirect", ) asyncio.run(main()) """ _response = await self._raw_client.o_auth_authorize( id, after_token_redirect=after_token_redirect, request_options=request_options ) return _response.data ================================================ FILE: src/cohere/connectors/raw_client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from json.decoder import JSONDecodeError from ..core.api_error import ApiError from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.http_response import AsyncHttpResponse, HttpResponse from ..core.jsonable_encoder import jsonable_encoder from ..core.parse_error import ParsingError from ..core.request_options import RequestOptions from ..core.serialization import convert_and_respect_annotation_metadata from ..core.unchecked_base_model import construct_type from ..errors.bad_request_error import BadRequestError from ..errors.client_closed_request_error import ClientClosedRequestError from ..errors.forbidden_error import ForbiddenError from ..errors.gateway_timeout_error import GatewayTimeoutError from ..errors.internal_server_error import InternalServerError from ..errors.invalid_token_error import InvalidTokenError from ..errors.not_found_error import NotFoundError from ..errors.not_implemented_error import NotImplementedError from ..errors.service_unavailable_error import ServiceUnavailableError from ..errors.too_many_requests_error import TooManyRequestsError from ..errors.unauthorized_error import UnauthorizedError from ..errors.unprocessable_entity_error import UnprocessableEntityError from ..types.create_connector_o_auth import 
CreateConnectorOAuth from ..types.create_connector_response import CreateConnectorResponse from ..types.create_connector_service_auth import CreateConnectorServiceAuth from ..types.delete_connector_response import DeleteConnectorResponse from ..types.get_connector_response import GetConnectorResponse from ..types.list_connectors_response import ListConnectorsResponse from ..types.o_auth_authorize_response import OAuthAuthorizeResponse from ..types.update_connector_response import UpdateConnectorResponse from pydantic import ValidationError # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) class RawConnectorsClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._client_wrapper = client_wrapper def list( self, *, limit: typing.Optional[float] = None, offset: typing.Optional[float] = None, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[ListConnectorsResponse]: """ Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. Parameters ---------- limit : typing.Optional[float] Maximum number of connectors to return [0, 100]. offset : typing.Optional[float] Number of connectors to skip before returning results [0, inf]. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[ListConnectorsResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/connectors", method="GET", params={ "limit": limit, "offset": offset, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListConnectorsResponse, construct_type( type_=ListConnectorsResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), 
body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def create( self, *, name: str, url: str, description: typing.Optional[str] = OMIT, excludes: typing.Optional[typing.Sequence[str]] = OMIT, oauth: typing.Optional[CreateConnectorOAuth] = OMIT, active: typing.Optional[bool] = OMIT, continue_on_failure: typing.Optional[bool] = OMIT, service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[CreateConnectorResponse]: """ Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. 
See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information. Parameters ---------- name : str A human-readable name for the connector. url : str The URL of the connector that will be used to search for documents. description : typing.Optional[str] A description of the connector. excludes : typing.Optional[typing.Sequence[str]] A list of fields to exclude from the prompt (fields remain in the document). oauth : typing.Optional[CreateConnectorOAuth] The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified. active : typing.Optional[bool] Whether the connector is active or not. continue_on_failure : typing.Optional[bool] Whether a chat request should continue or not if the request to this connector fails. service_auth : typing.Optional[CreateConnectorServiceAuth] The service to service authentication configuration for the connector. Cannot be specified if oauth is specified. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[CreateConnectorResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/connectors", method="POST", json={ "name": name, "description": description, "url": url, "excludes": excludes, "oauth": convert_and_respect_annotation_metadata( object_=oauth, annotation=CreateConnectorOAuth, direction="write" ), "active": active, "continue_on_failure": continue_on_failure, "service_auth": convert_and_respect_annotation_metadata( object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write" ), }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CreateConnectorResponse, construct_type( type_=CreateConnectorResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), 
body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def get( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[GetConnectorResponse]: """ Retrieve a connector by ID. 
See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to retrieve. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[GetConnectorResponse] OK """ _response = self._client_wrapper.httpx_client.request( f"v1/connectors/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( GetConnectorResponse, construct_type( type_=GetConnectorResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( 
typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def delete( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[DeleteConnectorResponse]: """ Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to delete. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[DeleteConnectorResponse] OK """ _response = self._client_wrapper.httpx_client.request( f"v1/connectors/{jsonable_encoder(id)}", method="DELETE", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DeleteConnectorResponse, construct_type( type_=DeleteConnectorResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( 
typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def update( self, id: str, *, name: typing.Optional[str] = OMIT, url: typing.Optional[str] = OMIT, excludes: typing.Optional[typing.Sequence[str]] = OMIT, oauth: typing.Optional[CreateConnectorOAuth] = OMIT, active: typing.Optional[bool] = OMIT, continue_on_failure: typing.Optional[bool] = OMIT, service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[UpdateConnectorResponse]: """ Update a connector by ID. Omitted fields will not be updated. See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. 
Parameters ---------- id : str The ID of the connector to update. name : typing.Optional[str] A human-readable name for the connector. url : typing.Optional[str] The URL of the connector that will be used to search for documents. excludes : typing.Optional[typing.Sequence[str]] A list of fields to exclude from the prompt (fields remain in the document). oauth : typing.Optional[CreateConnectorOAuth] The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified. active : typing.Optional[bool] continue_on_failure : typing.Optional[bool] service_auth : typing.Optional[CreateConnectorServiceAuth] The service to service authentication configuration for the connector. Cannot be specified if oauth is specified. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[UpdateConnectorResponse] OK """ _response = self._client_wrapper.httpx_client.request( f"v1/connectors/{jsonable_encoder(id)}", method="PATCH", json={ "name": name, "url": url, "excludes": excludes, "oauth": convert_and_respect_annotation_metadata( object_=oauth, annotation=CreateConnectorOAuth, direction="write" ), "active": active, "continue_on_failure": continue_on_failure, "service_auth": convert_and_respect_annotation_metadata( object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write" ), }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( UpdateConnectorResponse, construct_type( type_=UpdateConnectorResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), 
) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def o_auth_authorize( self, id: str, *, after_token_redirect: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[OAuthAuthorizeResponse]: """ Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information. Parameters ---------- id : str The ID of the connector to authorize. after_token_redirect : typing.Optional[str] The URL to redirect to after the connector has been authorized. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[OAuthAuthorizeResponse] OK """ _response = self._client_wrapper.httpx_client.request( f"v1/connectors/{jsonable_encoder(id)}/oauth/authorize", method="POST", params={ "after_token_redirect": after_token_redirect, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( OAuthAuthorizeResponse, construct_type( type_=OAuthAuthorizeResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise 
ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawConnectorsClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper async def list( self, *, limit: typing.Optional[float] = None, offset: typing.Optional[float] = None, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[ListConnectorsResponse]: """ Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. 
Parameters ---------- limit : typing.Optional[float] Maximum number of connectors to return [0, 100]. offset : typing.Optional[float] Number of connectors to skip before returning results [0, inf]. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[ListConnectorsResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/connectors", method="GET", params={ "limit": limit, "offset": offset, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListConnectorsResponse, construct_type( type_=ListConnectorsResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if 
_response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def create( self, *, name: str, url: str, description: typing.Optional[str] = OMIT, excludes: typing.Optional[typing.Sequence[str]] = OMIT, oauth: typing.Optional[CreateConnectorOAuth] = OMIT, active: typing.Optional[bool] = OMIT, continue_on_failure: typing.Optional[bool] = OMIT, service_auth: 
typing.Optional[CreateConnectorServiceAuth] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[CreateConnectorResponse]: """ Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information. Parameters ---------- name : str A human-readable name for the connector. url : str The URL of the connector that will be used to search for documents. description : typing.Optional[str] A description of the connector. excludes : typing.Optional[typing.Sequence[str]] A list of fields to exclude from the prompt (fields remain in the document). oauth : typing.Optional[CreateConnectorOAuth] The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified. active : typing.Optional[bool] Whether the connector is active or not. continue_on_failure : typing.Optional[bool] Whether a chat request should continue or not if the request to this connector fails. service_auth : typing.Optional[CreateConnectorServiceAuth] The service to service authentication configuration for the connector. Cannot be specified if oauth is specified. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[CreateConnectorResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/connectors", method="POST", json={ "name": name, "description": description, "url": url, "excludes": excludes, "oauth": convert_and_respect_annotation_metadata( object_=oauth, annotation=CreateConnectorOAuth, direction="write" ), "active": active, "continue_on_failure": continue_on_failure, "service_auth": convert_and_respect_annotation_metadata( object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write" ), }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CreateConnectorResponse, construct_type( type_=CreateConnectorResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def get( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[GetConnectorResponse]: """ Retrieve a connector by ID. 
See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to retrieve. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[GetConnectorResponse] OK """ _response = await self._client_wrapper.httpx_client.request( f"v1/connectors/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( GetConnectorResponse, construct_type( type_=GetConnectorResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), 
body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def delete( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[DeleteConnectorResponse]: """ Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information. Parameters ---------- id : str The ID of the connector to delete. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[DeleteConnectorResponse] OK """ _response = await self._client_wrapper.httpx_client.request( f"v1/connectors/{jsonable_encoder(id)}", method="DELETE", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DeleteConnectorResponse, construct_type( type_=DeleteConnectorResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), 
body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
            # Remainder of the enclosing method's error ladder: map each
            # documented server-side status onto its typed SDK exception.
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            # Unrecognized status: fall through to a generic ApiError below.
            _response_json = _response.json()
        except JSONDecodeError:
            # Body was not valid JSON; surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            # Body was JSON but did not match the expected response model.
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    async def update(
        self,
        id: str,
        *,
        name: typing.Optional[str] = OMIT,
        url: typing.Optional[str] = OMIT,
        excludes: typing.Optional[typing.Sequence[str]] = OMIT,
        oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
        active: typing.Optional[bool] = OMIT,
        continue_on_failure: typing.Optional[bool] = OMIT,
        service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[UpdateConnectorResponse]:
        """
        Update a connector by ID. Omitted fields will not be updated.
        See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.

        Parameters
        ----------
        id : str
            The ID of the connector to update.

        name : typing.Optional[str]
            A human-readable name for the connector.

        url : typing.Optional[str]
            The URL of the connector that will be used to search for documents.

        excludes : typing.Optional[typing.Sequence[str]]
            A list of fields to exclude from the prompt (fields remain in the document).

        oauth : typing.Optional[CreateConnectorOAuth]
            The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.

        active : typing.Optional[bool]

        continue_on_failure : typing.Optional[bool]

        service_auth : typing.Optional[CreateConnectorServiceAuth]
            The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[UpdateConnectorResponse]
            OK
        """
        # OMIT-sentinel fields are stripped from the JSON body (omit=OMIT below),
        # so omitted fields are genuinely absent rather than sent as null.
        _response = await self._client_wrapper.httpx_client.request(
            f"v1/connectors/{jsonable_encoder(id)}",
            method="PATCH",
            json={
                "name": name,
                "url": url,
                "excludes": excludes,
                "oauth": convert_and_respect_annotation_metadata(
                    object_=oauth, annotation=CreateConnectorOAuth, direction="write"
                ),
                "active": active,
                "continue_on_failure": continue_on_failure,
                "service_auth": convert_and_respect_annotation_metadata(
                    object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write"
                ),
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    UpdateConnectorResponse,
                    construct_type(
                        type_=UpdateConnectorResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            # Error ladder: map each documented status code to a typed exception.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            # Unrecognized status: fall through to a generic ApiError below.
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    async def o_auth_authorize(
        self,
        id: str,
        *,
        after_token_redirect: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[OAuthAuthorizeResponse]:
        """
        Authorize the connector with the given ID for the connector oauth app.
        See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information.

        Parameters
        ----------
        id : str
            The ID of the connector to authorize.

        after_token_redirect : typing.Optional[str]
            The URL to redirect to after the connector has been authorized.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[OAuthAuthorizeResponse]
            OK
        """
        _response = await self._client_wrapper.httpx_client.request(
            f"v1/connectors/{jsonable_encoder(id)}/oauth/authorize",
            method="POST",
            params={
                "after_token_redirect": after_token_redirect,
            },
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    OAuthAuthorizeResponse,
                    construct_type(
                        type_=OAuthAuthorizeResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            # Error ladder: map each documented status code to a typed exception.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)


================================================
FILE: src/cohere/core/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module

# Static imports for type checkers only; at runtime these names are resolved
# lazily through the module-level __getattr__ below (PEP 562).
if typing.TYPE_CHECKING:
    from .api_error import ApiError
    from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
    from .datetime_utils import Rfc2822DateTime, parse_rfc2822_datetime, serialize_datetime
    from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
    from .http_client import AsyncHttpClient, HttpClient
    from .http_response import AsyncHttpResponse, HttpResponse
    from .jsonable_encoder import jsonable_encoder
    from .logging import ConsoleLogger, ILogger, LogConfig, LogLevel, Logger, create_logger
    from .parse_error import ParsingError
    from .pydantic_utilities import (
        IS_PYDANTIC_V2,
        UniversalBaseModel,
        UniversalRootModel,
        parse_obj_as,
        universal_field_validator,
        universal_root_validator,
        update_forward_refs,
    )
    from .query_encoder import encode_query
    from .remove_none_from_dict import remove_none_from_dict
    from .request_options import RequestOptions
    from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
    from .unchecked_base_model import UncheckedBaseModel, UnionMetadata, construct_type

# Maps each public name to the relative submodule that defines it; consumed by
# __getattr__ below to import on first access.
_dynamic_imports: typing.Dict[str, str] = {
    "ApiError": ".api_error",
    "AsyncClientWrapper": ".client_wrapper",
    "AsyncHttpClient": ".http_client",
    "AsyncHttpResponse": ".http_response",
    "BaseClientWrapper": ".client_wrapper",
    "ConsoleLogger": ".logging",
    "FieldMetadata": ".serialization",
    "File": ".file",
    "HttpClient": ".http_client",
    "HttpResponse": ".http_response",
    "ILogger": ".logging",
    "IS_PYDANTIC_V2": ".pydantic_utilities",
    "LogConfig": ".logging",
    "LogLevel": ".logging",
    "Logger": ".logging",
    "ParsingError": ".parse_error",
    "RequestOptions": ".request_options",
    "Rfc2822DateTime": ".datetime_utils",
    "SyncClientWrapper": ".client_wrapper",
    "UncheckedBaseModel": ".unchecked_base_model",
    "UnionMetadata": ".unchecked_base_model",
    "UniversalBaseModel": ".pydantic_utilities",
    "UniversalRootModel": ".pydantic_utilities",
    "construct_type": ".unchecked_base_model",
    "convert_and_respect_annotation_metadata": ".serialization",
    "convert_file_dict_to_httpx_tuples": ".file",
    "create_logger": ".logging",
    "encode_query": ".query_encoder",
    "jsonable_encoder": ".jsonable_encoder",
    "parse_obj_as": ".pydantic_utilities",
    "parse_rfc2822_datetime": ".datetime_utils",
    "remove_none_from_dict": ".remove_none_from_dict",
    "serialize_datetime": ".datetime_utils",
    "universal_field_validator": ".pydantic_utilities",
    "universal_root_validator": ".pydantic_utilities",
    "update_forward_refs": ".pydantic_utilities",
    "with_content_type": ".file",
}


def __getattr__(attr_name: str) -> typing.Any:
    """Lazily import ``attr_name`` from its submodule on first access (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # If the requested name is itself a submodule, return the module object.
        if module_name == f".{attr_name}":
            return module
        else:
            return getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e


def __dir__():
    """Expose the lazily importable names to dir() and autocompletion."""
    lazy_attrs = list(_dynamic_imports.keys())
    return sorted(lazy_attrs)


__all__ = [
    "ApiError",
    "AsyncClientWrapper",
    "AsyncHttpClient",
    "AsyncHttpResponse",
    "BaseClientWrapper",
    "ConsoleLogger",
    "FieldMetadata",
    "File",
    "HttpClient",
    "HttpResponse",
    "ILogger",
    "IS_PYDANTIC_V2",
    "LogConfig",
    "LogLevel",
    "Logger",
    "ParsingError",
    "RequestOptions",
    "Rfc2822DateTime",
    "SyncClientWrapper",
    "UncheckedBaseModel",
    "UnionMetadata",
    "UniversalBaseModel",
    "UniversalRootModel",
    "construct_type",
    "convert_and_respect_annotation_metadata",
    "convert_file_dict_to_httpx_tuples",
    "create_logger",
    "encode_query",
    "jsonable_encoder",
    "parse_obj_as",
    "parse_rfc2822_datetime",
    "remove_none_from_dict",
    "serialize_datetime",
    "universal_field_validator",
    "universal_root_validator",
    "update_forward_refs",
    "with_content_type",
]
================================================
FILE: src/cohere/core/api_error.py
================================================
# This file was auto-generated by Fern from our API Definition.

from typing import Any, Dict, Optional


class ApiError(Exception):
    """Base exception for HTTP API failures, carrying headers, status, and body of the failed response."""

    # Response headers of the failed call (None if unavailable).
    headers: Optional[Dict[str, str]]
    # HTTP status code of the failed call (None if unavailable).
    status_code: Optional[int]
    # Parsed JSON body, or raw text when the body was not valid JSON.
    body: Any

    def __init__(
        self,
        *,
        headers: Optional[Dict[str, str]] = None,
        status_code: Optional[int] = None,
        body: Any = None,
    ) -> None:
        self.headers = headers
        self.status_code = status_code
        self.body = body

    def __str__(self) -> str:
        return f"headers: {self.headers}, status_code: {self.status_code}, body: {self.body}"


================================================
FILE: src/cohere/core/client_wrapper.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import httpx

from .http_client import AsyncHttpClient, HttpClient
from .logging import LogConfig, Logger


class BaseClientWrapper:
    """Holds client-level configuration (auth token, base URL, headers, timeout, logging)."""

    def __init__(
        self,
        *,
        client_name: typing.Optional[str] = None,
        token: typing.Union[str, typing.Callable[[], str]],
        headers: typing.Optional[typing.Dict[str, str]] = None,
        base_url: str,
        timeout: typing.Optional[float] = None,
        logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
    ):
        self._client_name = client_name
        # Either a static token or a zero-argument callable producing one.
        self._token = token
        self._headers = headers
        self._base_url = base_url
        self._timeout = timeout
        self._logging = logging

    def get_headers(self) -> typing.Dict[str, str]:
        """Build the default header set (SDK telemetry + custom headers + Bearer auth)."""
        import platform

        headers: typing.Dict[str, str] = {
            "User-Agent": "cohere/6.1.0",
            "X-Fern-Language": "Python",
            "X-Fern-Runtime": f"python/{platform.python_version()}",
            "X-Fern-Platform": f"{platform.system().lower()}/{platform.release()}",
            "X-Fern-SDK-Name": "cohere",
            "X-Fern-SDK-Version": "6.1.0",
            **(self.get_custom_headers() or {}),
        }
        if self._client_name is not None:
            headers["X-Client-Name"] = self._client_name
        headers["Authorization"] = f"Bearer {self._get_token()}"
        return headers

    def _get_token(self) -> str:
        # Resolve the token, invoking the provider callable if one was given.
        if isinstance(self._token, str):
            return self._token
        else:
            return self._token()

    def get_custom_headers(self) -> typing.Optional[typing.Dict[str, str]]:
        return self._headers

    def get_base_url(self) -> str:
        return self._base_url

    def get_timeout(self) -> typing.Optional[float]:
        return self._timeout


class SyncClientWrapper(BaseClientWrapper):
    """Client wrapper bound to a synchronous httpx.Client."""

    def __init__(
        self,
        *,
        client_name: typing.Optional[str] = None,
        token: typing.Union[str, typing.Callable[[], str]],
        headers: typing.Optional[typing.Dict[str, str]] = None,
        base_url: str,
        timeout: typing.Optional[float] = None,
        logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
        httpx_client: httpx.Client,
    ):
        super().__init__(
            client_name=client_name, token=token, headers=headers, base_url=base_url, timeout=timeout, logging=logging
        )
        # Defaults are passed as callables so each request re-reads current config.
        self.httpx_client = HttpClient(
            httpx_client=httpx_client,
            base_headers=self.get_headers,
            base_timeout=self.get_timeout,
            base_url=self.get_base_url,
            logging_config=self._logging,
        )


class AsyncClientWrapper(BaseClientWrapper):
    """Client wrapper bound to an httpx.AsyncClient, with optional async token provider."""

    def __init__(
        self,
        *,
        client_name: typing.Optional[str] = None,
        token: typing.Union[str, typing.Callable[[], str]],
        headers: typing.Optional[typing.Dict[str, str]] = None,
        base_url: str,
        timeout: typing.Optional[float] = None,
        logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
        async_token: typing.Optional[typing.Callable[[], typing.Awaitable[str]]] = None,
        httpx_client: httpx.AsyncClient,
    ):
        super().__init__(
            client_name=client_name, token=token, headers=headers, base_url=base_url, timeout=timeout, logging=logging
        )
        self._async_token = async_token
        self.httpx_client = AsyncHttpClient(
            httpx_client=httpx_client,
            base_headers=self.get_headers,
            base_timeout=self.get_timeout,
            base_url=self.get_base_url,
            async_base_headers=self.async_get_headers,
            logging_config=self._logging,
        )

    async def async_get_headers(self) -> typing.Dict[str, str]:
        """Like get_headers, but lets an async token provider override the Authorization header."""
        headers = self.get_headers()
        if self._async_token is not None:
            token = await self._async_token()
            headers["Authorization"] = f"Bearer {token}"
        return headers


================================================
FILE: src/cohere/core/datetime_utils.py
================================================
# This file was auto-generated by Fern from our API Definition.

import datetime as dt
from email.utils import parsedate_to_datetime
from typing import Any

import pydantic

# NOTE(review): duplicates core.pydantic_utilities.IS_PYDANTIC_V2 — presumably
# kept local so this module has no intra-package dependency; confirm.
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")


def parse_rfc2822_datetime(v: Any) -> dt.datetime:
    """
    Parse an RFC 2822 datetime string (e.g., "Wed, 02 Oct 2002 13:00:00 GMT") into a datetime object.

    If the value is already a datetime, return it as-is.
    Falls back to ISO 8601 parsing if RFC 2822 parsing fails.
    """
    if isinstance(v, dt.datetime):
        return v
    if isinstance(v, str):
        try:
            return parsedate_to_datetime(v)
        except Exception:
            pass
        # Fallback to ISO 8601 parsing ("Z" suffix normalized to a UTC offset
        # for datetime.fromisoformat compatibility).
        return dt.datetime.fromisoformat(v.replace("Z", "+00:00"))
    raise ValueError(f"Expected str or datetime, got {type(v)}")


class Rfc2822DateTime(dt.datetime):
    """A datetime subclass that parses RFC 2822 date strings.

    On Pydantic V1, uses __get_validators__ for pre-validation.
    On Pydantic V2, uses __get_pydantic_core_schema__ for BeforeValidator-style parsing.
    """

    @classmethod
    def __get_validators__(cls):  # type: ignore[no-untyped-def]
        # Pydantic V1 hook: yield validators to run on input values.
        yield parse_rfc2822_datetime

    @classmethod
    def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Any) -> Any:  # type: ignore[override]
        # Pydantic V2 hook: run the RFC 2822 parser before the standard datetime schema.
        from pydantic_core import core_schema

        return core_schema.no_info_before_validator_function(parse_rfc2822_datetime, core_schema.datetime_schema())


def serialize_datetime(v: dt.datetime) -> str:
    """
    Serialize a datetime including timezone info.

    Uses the timezone info provided if present, otherwise uses the current runtime's timezone info.

    UTC datetimes end in "Z" while all other timezones are represented as offset from UTC, e.g. +05:00.
""" def _serialize_zoned_datetime(v: dt.datetime) -> str: if v.tzinfo is not None and v.tzinfo.tzname(None) == dt.timezone.utc.tzname(None): # UTC is a special case where we use "Z" at the end instead of "+00:00" return v.isoformat().replace("+00:00", "Z") else: # Delegate to the typical +/- offset format return v.isoformat() if v.tzinfo is not None: return _serialize_zoned_datetime(v) else: local_tz = dt.datetime.now().astimezone().tzinfo localized_dt = v.replace(tzinfo=local_tz) return _serialize_zoned_datetime(localized_dt) ================================================ FILE: src/cohere/core/file.py ================================================ # This file was auto-generated by Fern from our API Definition. from typing import IO, Dict, List, Mapping, Optional, Tuple, Union, cast # File typing inspired by the flexibility of types within the httpx library # https://github.com/encode/httpx/blob/master/httpx/_types.py FileContent = Union[IO[bytes], bytes, str] File = Union[ # file (or bytes) FileContent, # (filename, file (or bytes)) Tuple[Optional[str], FileContent], # (filename, file (or bytes), content_type) Tuple[Optional[str], FileContent, Optional[str]], # (filename, file (or bytes), content_type, headers) Tuple[ Optional[str], FileContent, Optional[str], Mapping[str, str], ], ] def convert_file_dict_to_httpx_tuples( d: Dict[str, Union[File, List[File]]], ) -> List[Tuple[str, File]]: """ The format we use is a list of tuples, where the first element is the name of the file and the second is the file object. 
Typically HTTPX wants a dict, but to be able to send lists of files, you have to use the list approach (which also works for non-lists) https://github.com/encode/httpx/pull/1032 """ httpx_tuples = [] for key, file_like in d.items(): if isinstance(file_like, list): for file_like_item in file_like: httpx_tuples.append((key, file_like_item)) else: httpx_tuples.append((key, file_like)) return httpx_tuples def with_content_type(*, file: File, default_content_type: str) -> File: """ This function resolves to the file's content type, if provided, and defaults to the default_content_type value if not. """ if isinstance(file, tuple): if len(file) == 2: filename, content = cast(Tuple[Optional[str], FileContent], file) # type: ignore return (filename, content, default_content_type) elif len(file) == 3: filename, content, file_content_type = cast(Tuple[Optional[str], FileContent, Optional[str]], file) # type: ignore out_content_type = file_content_type or default_content_type return (filename, content, out_content_type) elif len(file) == 4: filename, content, file_content_type, headers = cast( # type: ignore Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], file ) out_content_type = file_content_type or default_content_type return (filename, content, out_content_type, headers) else: raise ValueError(f"Unexpected tuple length: {len(file)}") return (None, file, default_content_type) ================================================ FILE: src/cohere/core/force_multipart.py ================================================ # This file was auto-generated by Fern from our API Definition. from typing import Any, Dict class ForceMultipartDict(Dict[str, Any]): """ A dictionary subclass that always evaluates to True in boolean contexts. This is used to force multipart/form-data encoding in HTTP requests even when the dictionary is empty, which would normally evaluate to False. 
""" def __bool__(self) -> bool: return True FORCE_MULTIPART = ForceMultipartDict() ================================================ FILE: src/cohere/core/http_client.py ================================================ # This file was auto-generated by Fern from our API Definition. import asyncio import email.utils import re import time import typing from contextlib import asynccontextmanager, contextmanager from random import random import httpx from .file import File, convert_file_dict_to_httpx_tuples from .force_multipart import FORCE_MULTIPART from .jsonable_encoder import jsonable_encoder from .logging import LogConfig, Logger, create_logger from .query_encoder import encode_query from .remove_none_from_dict import remove_none_from_dict as remove_none_from_dict from .request_options import RequestOptions from httpx._types import RequestFiles INITIAL_RETRY_DELAY_SECONDS = 1.0 MAX_RETRY_DELAY_SECONDS = 60.0 JITTER_FACTOR = 0.2 # 20% random jitter def _parse_retry_after(response_headers: httpx.Headers) -> typing.Optional[float]: """ This function parses the `Retry-After` header in a HTTP response and returns the number of seconds to wait. Inspired by the urllib3 retry implementation. """ retry_after_ms = response_headers.get("retry-after-ms") if retry_after_ms is not None: try: return int(retry_after_ms) / 1000 if retry_after_ms > 0 else 0 except Exception: pass retry_after = response_headers.get("retry-after") if retry_after is None: return None # Attempt to parse the header as an int. if re.match(r"^\s*[0-9]+\s*$", retry_after): seconds = float(retry_after) # Fallback to parsing it as a date. else: retry_date_tuple = email.utils.parsedate_tz(retry_after) if retry_date_tuple is None: return None if retry_date_tuple[9] is None: # Python 2 # Assume UTC if no timezone was specified # On Python2.7, parsedate_tz returns None for a timezone offset # instead of 0 if no timezone is given, where mktime_tz treats # a None timezone offset as local time. 
            retry_date_tuple = retry_date_tuple[:9] + (0,) + retry_date_tuple[10:]
        retry_date = email.utils.mktime_tz(retry_date_tuple)
        # Convert the absolute HTTP-date into a relative wait.
        seconds = retry_date - time.time()
    if seconds < 0:
        seconds = 0
    return seconds


def _add_positive_jitter(delay: float) -> float:
    """Add positive jitter (0-20%) to prevent thundering herd."""
    jitter_multiplier = 1 + random() * JITTER_FACTOR
    return delay * jitter_multiplier


def _add_symmetric_jitter(delay: float) -> float:
    """Add symmetric jitter (±10%) for exponential backoff."""
    jitter_multiplier = 1 + (random() - 0.5) * JITTER_FACTOR
    return delay * jitter_multiplier


def _parse_x_ratelimit_reset(response_headers: httpx.Headers) -> typing.Optional[float]:
    """
    Parse the X-RateLimit-Reset header (Unix timestamp in seconds).
    Returns seconds to wait, or None if header is missing/invalid.
    """
    reset_time_str = response_headers.get("x-ratelimit-reset")
    if reset_time_str is None:
        return None
    try:
        reset_time = int(reset_time_str)
        delay = reset_time - time.time()
        # Only positive waits are meaningful; a reset in the past yields None.
        if delay > 0:
            return delay
    except (ValueError, TypeError):
        pass
    return None


def _retry_timeout(response: httpx.Response, retries: int) -> float:
    """
    Determine the amount of time to wait before retrying a request.
    This function begins by trying to parse a retry-after header from the response, and then proceeds to use exponential
    backoff with a jitter to determine the number of seconds to wait.
    """
    # 1. Check Retry-After header first
    retry_after = _parse_retry_after(response.headers)
    if retry_after is not None and retry_after > 0:
        return min(retry_after, MAX_RETRY_DELAY_SECONDS)

    # 2. Check X-RateLimit-Reset header (with positive jitter)
    ratelimit_reset = _parse_x_ratelimit_reset(response.headers)
    if ratelimit_reset is not None:
        return _add_positive_jitter(min(ratelimit_reset, MAX_RETRY_DELAY_SECONDS))

    # 3. Fall back to exponential backoff (with symmetric jitter)
    backoff = min(INITIAL_RETRY_DELAY_SECONDS * pow(2.0, retries), MAX_RETRY_DELAY_SECONDS)
    return _add_symmetric_jitter(backoff)


def _retry_timeout_from_retries(retries: int) -> float:
    """Determine retry timeout using exponential backoff when no response is available."""
    backoff = min(INITIAL_RETRY_DELAY_SECONDS * pow(2.0, retries), MAX_RETRY_DELAY_SECONDS)
    return _add_symmetric_jitter(backoff)


def _should_retry(response: httpx.Response) -> bool:
    # Retry on every 5xx plus 429 (rate limited), 408 (request timeout),
    # and 409 (conflict).
    retryable_400s = [429, 408, 409]
    return response.status_code >= 500 or response.status_code in retryable_400s


# Lower-cased header names whose values must never be written to logs.
_SENSITIVE_HEADERS = frozenset(
    {
        "authorization",
        "www-authenticate",
        "x-api-key",
        "api-key",
        "apikey",
        "x-api-token",
        "x-auth-token",
        "auth-token",
        "cookie",
        "set-cookie",
        "proxy-authorization",
        "proxy-authenticate",
        "x-csrf-token",
        "x-xsrf-token",
        "x-session-token",
        "x-access-token",
    }
)


def _redact_headers(headers: typing.Dict[str, str]) -> typing.Dict[str, str]:
    # Replace sensitive header values with a placeholder before logging.
    return {k: ("[REDACTED]" if k.lower() in _SENSITIVE_HEADERS else v) for k, v in headers.items()}


def _build_url(base_url: str, path: typing.Optional[str]) -> str:
    """
    Build a full URL by joining a base URL with a path.

    This function correctly handles base URLs that contain path prefixes
    (e.g., tenant-based URLs) by using string concatenation instead of
    urllib.parse.urljoin(), which would incorrectly strip path components
    when the path starts with '/'.

    Example:
        >>> _build_url("https://cloud.example.com/org/tenant/api", "/users")
        'https://cloud.example.com/org/tenant/api/users'

    Args:
        base_url: The base URL, which may contain path prefixes.
        path: The path to append. Can be None or empty string.

    Returns:
        The full URL with base_url and path properly joined.
""" if not path: return base_url return f"{base_url.rstrip('/')}/{path.lstrip('/')}" def _maybe_filter_none_from_multipart_data( data: typing.Optional[typing.Any], request_files: typing.Optional[RequestFiles], force_multipart: typing.Optional[bool], ) -> typing.Optional[typing.Any]: """ Filter None values from data body for multipart/form requests. This prevents httpx from converting None to empty strings in multipart encoding. Only applies when files are present or force_multipart is True. """ if data is not None and isinstance(data, typing.Mapping) and (request_files or force_multipart): return remove_none_from_dict(data) return data def remove_omit_from_dict( original: typing.Dict[str, typing.Optional[typing.Any]], omit: typing.Optional[typing.Any], ) -> typing.Dict[str, typing.Any]: if omit is None: return original new: typing.Dict[str, typing.Any] = {} for key, value in original.items(): if value is not omit: new[key] = value return new def maybe_filter_request_body( data: typing.Optional[typing.Any], request_options: typing.Optional[RequestOptions], omit: typing.Optional[typing.Any], ) -> typing.Optional[typing.Any]: if data is None: return ( jsonable_encoder(request_options.get("additional_body_parameters", {})) or {} if request_options is not None else None ) elif not isinstance(data, typing.Mapping): data_content = jsonable_encoder(data) else: data_content = { **(jsonable_encoder(remove_omit_from_dict(data, omit))), # type: ignore **( jsonable_encoder(request_options.get("additional_body_parameters", {})) or {} if request_options is not None else {} ), } return data_content # Abstracted out for testing purposes def get_request_body( *, json: typing.Optional[typing.Any], data: typing.Optional[typing.Any], request_options: typing.Optional[RequestOptions], omit: typing.Optional[typing.Any], ) -> typing.Tuple[typing.Optional[typing.Any], typing.Optional[typing.Any]]: json_body = None data_body = None if data is not None: data_body = 
maybe_filter_request_body(data, request_options, omit) else: # If both data and json are None, we send json data in the event extra properties are specified json_body = maybe_filter_request_body(json, request_options, omit) has_additional_body_parameters = bool( request_options is not None and request_options.get("additional_body_parameters") ) # Only collapse empty dict to None when the body was not explicitly provided # and there are no additional body parameters. This preserves explicit empty # bodies (e.g., when an endpoint has a request body type but all fields are optional). if json_body == {} and json is None and not has_additional_body_parameters: json_body = None if data_body == {} and data is None and not has_additional_body_parameters: data_body = None return json_body, data_body class HttpClient: def __init__( self, *, httpx_client: httpx.Client, base_timeout: typing.Callable[[], typing.Optional[float]], base_headers: typing.Callable[[], typing.Dict[str, str]], base_url: typing.Optional[typing.Callable[[], str]] = None, base_max_retries: int = 2, logging_config: typing.Optional[typing.Union[LogConfig, Logger]] = None, ): self.base_url = base_url self.base_timeout = base_timeout self.base_headers = base_headers self.base_max_retries = base_max_retries self.httpx_client = httpx_client self.logger = create_logger(logging_config) def get_base_url(self, maybe_base_url: typing.Optional[str]) -> str: base_url = maybe_base_url if self.base_url is not None and base_url is None: base_url = self.base_url() if base_url is None: raise ValueError("A base_url is required to make this request, please provide one and try again.") return base_url def request( self, path: typing.Optional[str] = None, *, method: str, base_url: typing.Optional[str] = None, params: typing.Optional[typing.Dict[str, typing.Any]] = None, json: typing.Optional[typing.Any] = None, data: typing.Optional[typing.Any] = None, content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], 
typing.AsyncIterator[bytes]]] = None, files: typing.Optional[ typing.Union[ typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]], typing.List[typing.Tuple[str, File]], ] ] = None, headers: typing.Optional[typing.Dict[str, typing.Any]] = None, request_options: typing.Optional[RequestOptions] = None, retries: int = 0, omit: typing.Optional[typing.Any] = None, force_multipart: typing.Optional[bool] = None, ) -> httpx.Response: base_url = self.get_base_url(base_url) timeout = ( request_options.get("timeout_in_seconds") if request_options is not None and request_options.get("timeout_in_seconds") is not None else self.base_timeout() ) json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit) request_files: typing.Optional[RequestFiles] = ( convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit)) if (files is not None and files is not omit and isinstance(files, dict)) else None ) if (request_files is None or len(request_files) == 0) and force_multipart: request_files = FORCE_MULTIPART data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart) # Compute encoded params separately to avoid passing empty list to httpx # (httpx strips existing query params from URL when params=[] is passed) _encoded_params = encode_query( jsonable_encoder( remove_none_from_dict( remove_omit_from_dict( { **(params if params is not None else {}), **( request_options.get("additional_query_parameters", {}) or {} if request_options is not None else {} ), }, omit, ) ) ) ) _request_url = _build_url(base_url, path) _request_headers = jsonable_encoder( remove_none_from_dict( { **self.base_headers(), **(headers if headers is not None else {}), **(request_options.get("additional_headers", {}) or {} if request_options is not None else {}), } ) ) if self.logger.is_debug(): self.logger.debug( "Making HTTP request", method=method, url=_request_url, 
headers=_redact_headers(_request_headers), has_body=json_body is not None or data_body is not None, ) max_retries: int = ( request_options.get("max_retries", self.base_max_retries) if request_options is not None else self.base_max_retries ) try: response = self.httpx_client.request( method=method, url=_request_url, headers=_request_headers, params=_encoded_params if _encoded_params else None, json=json_body, data=data_body, content=content, files=request_files, timeout=timeout, ) except (httpx.ConnectError, httpx.RemoteProtocolError): if retries < max_retries: time.sleep(_retry_timeout_from_retries(retries=retries)) return self.request( path=path, method=method, base_url=base_url, params=params, json=json, data=data, content=content, files=files, headers=headers, request_options=request_options, retries=retries + 1, omit=omit, force_multipart=force_multipart, ) raise if _should_retry(response=response): if retries < max_retries: time.sleep(_retry_timeout(response=response, retries=retries)) return self.request( path=path, method=method, base_url=base_url, params=params, json=json, data=data, content=content, files=files, headers=headers, request_options=request_options, retries=retries + 1, omit=omit, force_multipart=force_multipart, ) if self.logger.is_debug(): if 200 <= response.status_code < 400: self.logger.debug( "HTTP request succeeded", method=method, url=_request_url, status_code=response.status_code, ) if self.logger.is_error(): if response.status_code >= 400: self.logger.error( "HTTP request failed with error status", method=method, url=_request_url, status_code=response.status_code, ) return response @contextmanager def stream( self, path: typing.Optional[str] = None, *, method: str, base_url: typing.Optional[str] = None, params: typing.Optional[typing.Dict[str, typing.Any]] = None, json: typing.Optional[typing.Any] = None, data: typing.Optional[typing.Any] = None, content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], 
typing.AsyncIterator[bytes]]] = None, files: typing.Optional[ typing.Union[ typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]], typing.List[typing.Tuple[str, File]], ] ] = None, headers: typing.Optional[typing.Dict[str, typing.Any]] = None, request_options: typing.Optional[RequestOptions] = None, retries: int = 0, omit: typing.Optional[typing.Any] = None, force_multipart: typing.Optional[bool] = None, ) -> typing.Iterator[httpx.Response]: base_url = self.get_base_url(base_url) timeout = ( request_options.get("timeout_in_seconds") if request_options is not None and request_options.get("timeout_in_seconds") is not None else self.base_timeout() ) request_files: typing.Optional[RequestFiles] = ( convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit)) if (files is not None and files is not omit and isinstance(files, dict)) else None ) if (request_files is None or len(request_files) == 0) and force_multipart: request_files = FORCE_MULTIPART json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit) data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart) # Compute encoded params separately to avoid passing empty list to httpx # (httpx strips existing query params from URL when params=[] is passed) _encoded_params = encode_query( jsonable_encoder( remove_none_from_dict( remove_omit_from_dict( { **(params if params is not None else {}), **( request_options.get("additional_query_parameters", {}) if request_options is not None else {} ), }, omit, ) ) ) ) _request_url = _build_url(base_url, path) _request_headers = jsonable_encoder( remove_none_from_dict( { **self.base_headers(), **(headers if headers is not None else {}), **(request_options.get("additional_headers", {}) if request_options is not None else {}), } ) ) if self.logger.is_debug(): self.logger.debug( "Making streaming HTTP request", method=method, url=_request_url, 
    async def request(
        self,
        path: typing.Optional[str] = None,
        *,
        method: str,
        base_url: typing.Optional[str] = None,
        params: typing.Optional[typing.Dict[str, typing.Any]] = None,
        json: typing.Optional[typing.Any] = None,
        data: typing.Optional[typing.Any] = None,
        content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
        files: typing.Optional[
            typing.Union[
                typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
                typing.List[typing.Tuple[str, File]],
            ]
        ] = None,
        headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
        request_options: typing.Optional[RequestOptions] = None,
        retries: int = 0,
        omit: typing.Optional[typing.Any] = None,
        force_multipart: typing.Optional[bool] = None,
    ) -> httpx.Response:
        """Send one async HTTP request, retrying with backoff on transient failures.

        Retries happen in two cases: (a) connection-level errors
        (``httpx.ConnectError`` / ``RemoteProtocolError``) and (b) responses
        that ``_should_retry`` deems retryable (e.g. 429/5xx), each up to
        ``max_retries`` attempts via recursive re-invocation with
        ``retries + 1``.

        :param path: Path appended to the resolved base URL.
        :param method: HTTP verb, e.g. ``"GET"``.
        :param request_options: Optional per-call overrides (timeout,
            max_retries, extra query params/headers, body properties).
        :param retries: Internal counter of attempts already made; callers
            should leave it at 0.
        :param omit: Sentinel value to strip from params/headers/files/body.
        :param force_multipart: Force multipart encoding even without files.
        :return: The final ``httpx.Response`` (after any retries).
        """
        base_url = self.get_base_url(base_url)
        # Per-request timeout (if explicitly set) wins over the client default.
        timeout = (
            request_options.get("timeout_in_seconds")
            if request_options is not None and request_options.get("timeout_in_seconds") is not None
            else self.base_timeout()
        )
        request_files: typing.Optional[RequestFiles] = (
            convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
            if (files is not None and files is not omit and isinstance(files, dict))
            else None
        )
        # Callers may force multipart encoding even with no files attached.
        if (request_files is None or len(request_files) == 0) and force_multipart:
            request_files = FORCE_MULTIPART
        json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
        data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart)
        # Get headers (supports async token providers)
        _headers = await self._get_headers()
        # Compute encoded params separately to avoid passing empty list to httpx
        # (httpx strips existing query params from URL when params=[] is passed)
        _encoded_params = encode_query(
            jsonable_encoder(
                remove_none_from_dict(
                    remove_omit_from_dict(
                        {
                            **(params if params is not None else {}),
                            **(
                                # `or {}` also covers an explicit None stored under the key.
                                request_options.get("additional_query_parameters", {}) or {}
                                if request_options is not None
                                else {}
                            ),
                        },
                        omit,
                    )
                )
            )
        )
        _request_url = _build_url(base_url, path)
        _request_headers = jsonable_encoder(
            remove_none_from_dict(
                {
                    **_headers,
                    **(headers if headers is not None else {}),
                    **(request_options.get("additional_headers", {}) or {} if request_options is not None else {}),
                }
            )
        )
        if self.logger.is_debug():
            self.logger.debug(
                "Making HTTP request",
                method=method,
                url=_request_url,
                headers=_redact_headers(_request_headers),
                has_body=json_body is not None or data_body is not None,
            )
        max_retries: int = (
            request_options.get("max_retries", self.base_max_retries)
            if request_options is not None
            else self.base_max_retries
        )
        try:
            response = await self.httpx_client.request(
                method=method,
                url=_request_url,
                headers=_request_headers,
                params=_encoded_params if _encoded_params else None,
                json=json_body,
                data=data_body,
                content=content,
                files=request_files,
                timeout=timeout,
            )
        except (httpx.ConnectError, httpx.RemoteProtocolError):
            # Connection-level failure: back off and retry the whole call.
            if retries < max_retries:
                await asyncio.sleep(_retry_timeout_from_retries(retries=retries))
                return await self.request(
                    path=path,
                    method=method,
                    base_url=base_url,
                    params=params,
                    json=json,
                    data=data,
                    content=content,
                    files=files,
                    headers=headers,
                    request_options=request_options,
                    retries=retries + 1,
                    omit=omit,
                    force_multipart=force_multipart,
                )
            raise
        if _should_retry(response=response):
            # Retryable status (e.g. rate limit / server error): honor any
            # server-provided backoff hint via _retry_timeout, then retry.
            if retries < max_retries:
                await asyncio.sleep(_retry_timeout(response=response, retries=retries))
                return await self.request(
                    path=path,
                    method=method,
                    base_url=base_url,
                    params=params,
                    json=json,
                    data=data,
                    content=content,
                    files=files,
                    headers=headers,
                    request_options=request_options,
                    retries=retries + 1,
                    omit=omit,
                    force_multipart=force_multipart,
                )
        if self.logger.is_debug():
            if 200 <= response.status_code < 400:
                self.logger.debug(
                    "HTTP request succeeded",
                    method=method,
                    url=_request_url,
                    status_code=response.status_code,
                )
        if self.logger.is_error():
            if response.status_code >= 400:
                self.logger.error(
                    "HTTP request failed with error status",
                    method=method,
                    url=_request_url,
                    status_code=response.status_code,
                )
        return response
typing.Optional[typing.Union[File, typing.List[File]]]], typing.List[typing.Tuple[str, File]], ] ] = None, headers: typing.Optional[typing.Dict[str, typing.Any]] = None, request_options: typing.Optional[RequestOptions] = None, retries: int = 0, omit: typing.Optional[typing.Any] = None, force_multipart: typing.Optional[bool] = None, ) -> typing.AsyncIterator[httpx.Response]: base_url = self.get_base_url(base_url) timeout = ( request_options.get("timeout_in_seconds") if request_options is not None and request_options.get("timeout_in_seconds") is not None else self.base_timeout() ) request_files: typing.Optional[RequestFiles] = ( convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit)) if (files is not None and files is not omit and isinstance(files, dict)) else None ) if (request_files is None or len(request_files) == 0) and force_multipart: request_files = FORCE_MULTIPART json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit) data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart) # Get headers (supports async token providers) _headers = await self._get_headers() # Compute encoded params separately to avoid passing empty list to httpx # (httpx strips existing query params from URL when params=[] is passed) _encoded_params = encode_query( jsonable_encoder( remove_none_from_dict( remove_omit_from_dict( { **(params if params is not None else {}), **( request_options.get("additional_query_parameters", {}) if request_options is not None else {} ), }, omit=omit, ) ) ) ) _request_url = _build_url(base_url, path) _request_headers = jsonable_encoder( remove_none_from_dict( { **_headers, **(headers if headers is not None else {}), **(request_options.get("additional_headers", {}) if request_options is not None else {}), } ) ) if self.logger.is_debug(): self.logger.debug( "Making streaming HTTP request", method=method, url=_request_url, 
# This file was auto-generated by Fern from our API Definition.

from typing import Dict, Generic, TypeVar

import httpx

# Generic to represent the underlying type of the data wrapped by the HTTP response.
T = TypeVar("T")


class BaseHttpResponse:
    """Minimalist HTTP response wrapper that exposes response headers and status code."""

    _response: httpx.Response

    def __init__(self, response: httpx.Response):
        self._response = response

    @property
    def headers(self) -> Dict[str, str]:
        # Copied into a plain dict, detached from the underlying httpx headers object.
        return dict(self._response.headers)

    @property
    def status_code(self) -> int:
        return self._response.status_code


class HttpResponse(Generic[T], BaseHttpResponse):
    """HTTP response wrapper that exposes response headers and data."""

    # Parsed payload of type T supplied by the caller at construction time.
    _data: T

    def __init__(self, response: httpx.Response, data: T):
        super().__init__(response)
        self._data = data

    @property
    def data(self) -> T:
        """The parsed response payload."""
        return self._data

    def close(self) -> None:
        """Close the underlying httpx response (synchronous)."""
        self._response.close()


class AsyncHttpResponse(Generic[T], BaseHttpResponse):
    """HTTP response wrapper that exposes response headers and data."""

    # Parsed payload of type T supplied by the caller at construction time.
    _data: T

    def __init__(self, response: httpx.Response, data: T):
        super().__init__(response)
        self._data = data

    @property
    def data(self) -> T:
        """The parsed response payload."""
        return self._data

    async def close(self) -> None:
        """Close the underlying httpx response (asynchronous)."""
        await self._response.aclose()
# isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from ._api import EventSource, aconnect_sse, connect_sse from ._exceptions import SSEError from ._models import ServerSentEvent _dynamic_imports: typing.Dict[str, str] = { "EventSource": "._api", "SSEError": "._exceptions", "ServerSentEvent": "._models", "aconnect_sse": "._api", "connect_sse": "._api", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = ["EventSource", "SSEError", "ServerSentEvent", "aconnect_sse", "connect_sse"] ================================================ FILE: src/cohere/core/http_sse/_api.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import re
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Iterator, cast

import httpx

from ._decoders import SSEDecoder
from ._exceptions import SSEError
from ._models import ServerSentEvent


class EventSource:
    """Wraps a streaming ``httpx.Response`` and yields parsed Server-Sent Events."""

    def __init__(self, response: httpx.Response) -> None:
        self._response = response

    def _check_content_type(self) -> None:
        """Raise SSEError unless the response advertises ``text/event-stream``."""
        # partition(";") drops media-type parameters such as "; charset=utf-8".
        content_type = self._response.headers.get("content-type", "").partition(";")[0]
        if "text/event-stream" not in content_type:
            raise SSEError(
                f"Expected response header Content-Type to contain 'text/event-stream', got {content_type!r}"
            )

    def _get_charset(self) -> str:
        """Extract charset from Content-Type header, fallback to UTF-8."""
        content_type = self._response.headers.get("content-type", "")
        # Parse charset parameter using regex
        charset_match = re.search(r"charset=([^;\s]+)", content_type, re.IGNORECASE)
        if charset_match:
            charset = charset_match.group(1).strip("\"'")
            # Validate that it's a known encoding
            try:
                # Test if the charset is valid by trying to encode/decode
                "test".encode(charset).decode(charset)
                return charset
            except (LookupError, UnicodeError):
                # If charset is invalid, fall back to UTF-8
                pass
        # Default to UTF-8 if no charset specified or invalid charset
        return "utf-8"

    @property
    def response(self) -> httpx.Response:
        """The underlying httpx response being consumed."""
        return self._response

    def iter_sse(self) -> Iterator[ServerSentEvent]:
        """Synchronously iterate over complete SSE events from the response body.

        Bytes are decoded with the charset advertised by the server and split
        into lines manually so that partial lines across chunk boundaries are
        handled correctly.
        """
        self._check_content_type()
        decoder = SSEDecoder()
        charset = self._get_charset()
        buffer = ""
        for chunk in self._response.iter_bytes():
            # Decode chunk using detected charset
            text_chunk = chunk.decode(charset, errors="replace")
            buffer += text_chunk
            # Process complete lines
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                line = line.rstrip("\r")
                sse = decoder.decode(line)
                # when we reach a "\n\n" => line = ''
                # => decoder will attempt to return an SSE Event
                if sse is not None:
                    yield sse
        # Process any remaining data in buffer
        if buffer.strip():
            line = buffer.rstrip("\r")
            sse = decoder.decode(line)
            if sse is not None:
                yield sse

    async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]:
        """Asynchronously iterate over complete SSE events from the response body."""
        self._check_content_type()
        decoder = SSEDecoder()
        lines = cast(AsyncGenerator[str, None], self._response.aiter_lines())
        try:
            async for line in lines:
                line = line.rstrip("\n")
                sse = decoder.decode(line)
                if sse is not None:
                    yield sse
        finally:
            # Ensure the line generator is closed even if the consumer stops early.
            await lines.aclose()
from typing import List, Optional

from ._models import ServerSentEvent


class SSEDecoder:
    """Incremental, line-by-line decoder for the Server-Sent Events wire format."""

    def __init__(self) -> None:
        # Accumulated state for the event currently being assembled.
        self._event = ""
        self._data: List[str] = []
        self._last_event_id = ""
        self._retry: Optional[int] = None

    def decode(self, line: str) -> Optional[ServerSentEvent]:
        """Feed one line (without its newline); return an event on blank-line dispatch.

        Returns None for every field/comment line; a ServerSentEvent is only
        produced when an empty line terminates a non-empty accumulated event.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501
        if not line:
            # Blank line with no accumulated state: nothing to dispatch.
            if not self._event and not self._data and not self._last_event_id and self._retry is None:
                return None
            sse = ServerSentEvent(
                event=self._event,
                data="\n".join(self._data),
                id=self._last_event_id,
                retry=self._retry,
            )
            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = ""
            self._data = []
            self._retry = None
            return sse
        # Lines starting with ":" are comments and are discarded.
        if line.startswith(":"):
            return None
        fieldname, _, value = line.partition(":")
        # Per spec, a single leading space after the colon is stripped.
        if value.startswith(" "):
            value = value[1:]
        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # IDs containing NUL are ignored per spec.
            if "\0" in value:
                pass
            else:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                # Non-integer retry values are ignored per spec.
                pass
        else:
            pass  # Field is ignored.
        return None
# This file was auto-generated by Fern from our API Definition.

import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class ServerSentEvent:
    """One event received on a Server-Sent Events stream.

    Defaults follow the SSE specification: an event with no explicit
    ``event:`` field is a ``"message"`` event, and ``data``/``id`` default
    to the empty string. ``retry`` is the server's reconnection-delay hint
    in milliseconds, when provided.
    """

    event: str = "message"
    data: str = ""
    id: str = ""
    retry: Optional[int] = None

    def json(self) -> Any:
        """Parse the data field as JSON."""
        payload = self.data
        return json.loads(payload)
import base64
import dataclasses
import datetime as dt
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Union

import pydantic

from .datetime_utils import serialize_datetime
from .pydantic_utilities import (
    IS_PYDANTIC_V2,
    encode_by_type,
    to_jsonable_with_fallback,
)

SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]


def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None) -> Any:
    """Recursively convert *obj* into JSON-friendly Python primitives.

    Pydantic models become dicts (by alias), dataclasses become dicts,
    datetimes become ISO strings, bytes become base64 text, enums collapse
    to their values, paths become strings, and containers are converted
    element by element. Ellipsis (the generated SDKs' OMIT sentinel) is
    dropped from containers and encoded as None at the top level.

    :param obj: Any Python object to encode.
    :param custom_encoder: Optional mapping of type -> encoder callable that
        takes precedence over the built-in handling.
    :return: A structure composed only of JSON-compatible values.
    """
    custom_encoder = custom_encoder or {}
    # Generated SDKs use Ellipsis (`...`) as the sentinel value for "OMIT".
    # OMIT values should be excluded from serialized payloads.
    if obj is Ellipsis:
        return None
    if custom_encoder:
        # Exact type match first; fall back to isinstance for subclasses.
        if type(obj) in custom_encoder:
            return custom_encoder[type(obj)](obj)
        else:
            for encoder_type, encoder_instance in custom_encoder.items():
                if isinstance(obj, encoder_type):
                    return encoder_instance(obj)
    if isinstance(obj, pydantic.BaseModel):
        if IS_PYDANTIC_V2:
            encoder = getattr(obj.model_config, "json_encoders", {})  # type: ignore # Pydantic v2
        else:
            encoder = getattr(obj.__config__, "json_encoders", {})  # type: ignore # Pydantic v1
        if custom_encoder:
            # Merge into a fresh dict: calling encoder.update() here would
            # mutate the model class's shared json_encoders configuration.
            encoder = {**encoder, **custom_encoder}
        obj_dict = obj.dict(by_alias=True)
        # Unwrap root models so the payload is the root value itself.
        if "__root__" in obj_dict:
            obj_dict = obj_dict["__root__"]
        if "root" in obj_dict:
            obj_dict = obj_dict["root"]
        return jsonable_encoder(obj_dict, custom_encoder=encoder)
    if dataclasses.is_dataclass(obj):
        obj_dict = dataclasses.asdict(obj)  # type: ignore
        return jsonable_encoder(obj_dict, custom_encoder=custom_encoder)
    if isinstance(obj, bytes):
        return base64.b64encode(obj).decode("utf-8")
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, PurePath):
        return str(obj)
    if isinstance(obj, (str, int, float, type(None))):
        return obj
    # datetime must be checked before date: datetime is a date subclass.
    if isinstance(obj, dt.datetime):
        return serialize_datetime(obj)
    if isinstance(obj, dt.date):
        return str(obj)
    if isinstance(obj, dict):
        encoded_dict = {}
        for key, value in obj.items():
            # OMIT sentinel values are dropped from dict payloads entirely.
            if value is Ellipsis:
                continue
            encoded_key = jsonable_encoder(key, custom_encoder=custom_encoder)
            encoded_value = jsonable_encoder(value, custom_encoder=custom_encoder)
            encoded_dict[encoded_key] = encoded_value
        return encoded_dict
    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
        encoded_list = []
        for item in obj:
            # OMIT sentinel values are dropped from sequence payloads entirely.
            if item is Ellipsis:
                continue
            encoded_list.append(jsonable_encoder(item, custom_encoder=custom_encoder))
        return encoded_list

    def fallback_serializer(o: Any) -> Any:
        """Last-resort encoder for objects none of the branches above handled."""
        attempt_encode = encode_by_type(o)
        if attempt_encode is not None:
            return attempt_encode
        # Try converting to a dict (mappings/iterables of pairs), then to
        # the instance __dict__; surface both failures if neither works.
        try:
            data = dict(o)
        except Exception as e:
            errors: List[Exception] = []
            errors.append(e)
            try:
                data = vars(o)
            except Exception as e:
                errors.append(e)
                raise ValueError(errors) from e
        return jsonable_encoder(data, custom_encoder=custom_encoder)

    return to_jsonable_with_fallback(obj, fallback_serializer)
# This file was auto-generated by Fern from our API Definition.

import logging
import typing

# Log levels accepted in a LogConfig, ordered from least to most severe.
LogLevel = typing.Literal["debug", "info", "warn", "error"]

# Numeric severity ranks used to decide whether a message passes the filter.
_LOG_LEVEL_MAP: typing.Dict[LogLevel, int] = {
    "debug": 1,
    "info": 2,
    "warn": 3,
    "error": 4,
}


class ILogger(typing.Protocol):
    """Structural interface that any user-supplied logger must satisfy."""

    def debug(self, message: str, **kwargs: typing.Any) -> None: ...

    def info(self, message: str, **kwargs: typing.Any) -> None: ...

    def warn(self, message: str, **kwargs: typing.Any) -> None: ...

    def error(self, message: str, **kwargs: typing.Any) -> None: ...


class ConsoleLogger:
    """Default ILogger implementation backed by the stdlib ``logging`` module."""

    _logger: logging.Logger

    def __init__(self) -> None:
        self._logger = logging.getLogger("fern")
        # Attach the stream handler at most once, even if several clients are built.
        if not self._logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))
            self._logger.addHandler(handler)
            self._logger.setLevel(logging.DEBUG)

    def debug(self, message: str, **kwargs: typing.Any) -> None:
        self._logger.debug(message, extra=kwargs)

    def info(self, message: str, **kwargs: typing.Any) -> None:
        self._logger.info(message, extra=kwargs)

    def warn(self, message: str, **kwargs: typing.Any) -> None:
        self._logger.warning(message, extra=kwargs)

    def error(self, message: str, **kwargs: typing.Any) -> None:
        self._logger.error(message, extra=kwargs)


class LogConfig(typing.TypedDict, total=False):
    """User-facing logging configuration accepted by the client constructors."""

    level: LogLevel
    logger: ILogger
    silent: bool


class Logger:
    """Severity-filtered logger used internally by the HTTP clients.

    Messages below the configured level are dropped; ``silent`` suppresses
    all output regardless of level.
    """

    _level: int
    _logger: ILogger
    _silent: bool

    def __init__(self, *, level: LogLevel, logger: ILogger, silent: bool) -> None:
        self._level = _LOG_LEVEL_MAP[level]
        self._logger = logger
        self._silent = silent

    def _should_log(self, level: LogLevel) -> bool:
        # Silent mode wins over everything else.
        if self._silent:
            return False
        return _LOG_LEVEL_MAP[level] >= self._level

    def is_debug(self) -> bool:
        return self._should_log("debug")

    def is_info(self) -> bool:
        return self._should_log("info")

    def is_warn(self) -> bool:
        return self._should_log("warn")

    def is_error(self) -> bool:
        return self._should_log("error")

    def debug(self, message: str, **kwargs: typing.Any) -> None:
        if not self.is_debug():
            return
        self._logger.debug(message, **kwargs)

    def info(self, message: str, **kwargs: typing.Any) -> None:
        if not self.is_info():
            return
        self._logger.info(message, **kwargs)

    def warn(self, message: str, **kwargs: typing.Any) -> None:
        if not self.is_warn():
            return
        self._logger.warn(message, **kwargs)

    def error(self, message: str, **kwargs: typing.Any) -> None:
        if not self.is_error():
            return
        self._logger.error(message, **kwargs)


# Shared fallback used when no logging configuration is provided; silent by default.
_default_logger: Logger = Logger(level="info", logger=ConsoleLogger(), silent=True)


def create_logger(config: typing.Optional[typing.Union[LogConfig, Logger]] = None) -> Logger:
    """Normalize a user-supplied logging config into a ``Logger`` instance."""
    if config is None:
        return _default_logger
    if isinstance(config, Logger):
        return config
    return Logger(
        level=config.get("level", "info"),
        logger=config.get("logger", ConsoleLogger()),
        silent=config.get("silent", True),
    )
# This file was auto-generated by Fern from our API Definition.

from typing import Any, Dict, Optional


class ParsingError(Exception):
    """
    Raised when the SDK fails to parse/validate a response from the server.

    This typically indicates that the server returned a response whose shape
    does not match the expected schema.
    """

    headers: Optional[Dict[str, str]]
    status_code: Optional[int]
    body: Any
    cause: Optional[Exception]

    def __init__(
        self,
        *,
        headers: Optional[Dict[str, str]] = None,
        status_code: Optional[int] = None,
        body: Any = None,
        cause: Optional[Exception] = None,
    ) -> None:
        super().__init__()
        self.headers = headers
        self.status_code = status_code
        self.body = body
        self.cause = cause
        # Chain the underlying exception so tracebacks show the root failure.
        if cause is not None:
            self.__cause__ = cause

    def __str__(self) -> str:
        suffix = f", cause: {self.cause}" if self.cause is not None else ""
        return f"headers: {self.headers}, status_code: {self.status_code}, body: {self.body}{suffix}"
# nopycln: file
# Compatibility layer that lets the SDK run unchanged on Pydantic v1 and v2:
# version-gated shims (date parsing, typing introspection, JSON encoders),
# SSE payload parsing helpers, and the UniversalBaseModel base class.
import datetime as dt
import inspect
import json
import logging
from collections import defaultdict
from dataclasses import asdict
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import pydantic
import typing_extensions
from pydantic.fields import FieldInfo as _FieldInfo

_logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from .http_sse._models import ServerSentEvent

# True when the installed pydantic is a 2.x release; selects every shim below.
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

if IS_PYDANTIC_V2:
    # Reusable adapters (building a TypeAdapter is comparatively expensive).
    _datetime_adapter = pydantic.TypeAdapter(dt.datetime)  # type: ignore[attr-defined]
    _date_adapter = pydantic.TypeAdapter(dt.date)  # type: ignore[attr-defined]

    def parse_datetime(value: Any) -> dt.datetime:  # type: ignore[misc]
        """Coerce *value* to a datetime, passing datetimes through untouched."""
        if isinstance(value, dt.datetime):
            return value
        return _datetime_adapter.validate_python(value)

    def parse_date(value: Any) -> dt.date:  # type: ignore[misc]
        """Coerce *value* to a date; datetimes are truncated to their date part."""
        if isinstance(value, dt.datetime):
            return value.date()
        if isinstance(value, dt.date):
            return value
        return _date_adapter.validate_python(value)

    # Avoid importing from pydantic.v1 to maintain Python 3.14 compatibility.
    from typing import get_args as get_args  # type: ignore[assignment]
    from typing import get_origin as get_origin  # type: ignore[assignment]

    def is_literal_type(tp: Optional[Type[Any]]) -> bool:  # type: ignore[misc]
        """Return True if *tp* is a typing Literal[...] type."""
        return typing_extensions.get_origin(tp) is typing_extensions.Literal

    def is_union(tp: Optional[Type[Any]]) -> bool:  # type: ignore[misc]
        """Return True if *tp* is the Union special form or a Union[...] type."""
        return tp is Union or typing_extensions.get_origin(tp) is Union  # type: ignore[comparison-overlap]

    # Inline encoders_by_type to avoid importing from pydantic.v1.json
    import re as _re
    from collections import deque as _deque
    from decimal import Decimal as _Decimal
    from enum import Enum as _Enum
    from ipaddress import IPv4Address as _IPv4Address
    from ipaddress import IPv4Interface as _IPv4Interface
    from ipaddress import IPv4Network as _IPv4Network
    from ipaddress import IPv6Address as _IPv6Address
    from ipaddress import IPv6Interface as _IPv6Interface
    from ipaddress import IPv6Network as _IPv6Network
    from pathlib import Path as _Path
    from types import GeneratorType as _GeneratorType
    from uuid import UUID as _UUID

    from pydantic.fields import FieldInfo as ModelField  # type: ignore[no-redef, assignment]

    def _decimal_encoder(dec_value: Any) -> Any:
        """JSON-encode a Decimal: whole values become int, otherwise float."""
        if dec_value.as_tuple().exponent >= 0:
            return int(dec_value)
        return float(dec_value)

    # Mirror of pydantic v1's ENCODERS_BY_TYPE, reproduced here so v2
    # installs never touch the pydantic.v1 namespace.
    encoders_by_type: Dict[Type[Any], Callable[[Any], Any]] = {  # type: ignore[no-redef]
        bytes: lambda o: o.decode(),
        dt.date: lambda o: o.isoformat(),
        dt.datetime: lambda o: o.isoformat(),
        dt.time: lambda o: o.isoformat(),
        dt.timedelta: lambda td: td.total_seconds(),
        _Decimal: _decimal_encoder,
        _Enum: lambda o: o.value,
        frozenset: list,
        _deque: list,
        _GeneratorType: list,
        _IPv4Address: str,
        _IPv4Interface: str,
        _IPv4Network: str,
        _IPv6Address: str,
        _IPv6Interface: str,
        _IPv6Network: str,
        _Path: str,
        _re.Pattern: lambda o: o.pattern,
        set: list,
        _UUID: str,
    }
else:
    # Pydantic v1: the equivalents ship in the library itself.
    from pydantic.datetime_parse import parse_date as parse_date  # type: ignore[no-redef]
    from pydantic.datetime_parse import parse_datetime as parse_datetime  # type: ignore[no-redef]
    from pydantic.fields import ModelField as ModelField  # type: ignore[attr-defined, no-redef, assignment]
    from pydantic.json import ENCODERS_BY_TYPE as encoders_by_type  # type: ignore[no-redef]
    from pydantic.typing import get_args as get_args  # type: ignore[no-redef]
    from pydantic.typing import get_origin as get_origin  # type: ignore[no-redef]
    from pydantic.typing import is_literal_type as is_literal_type  # type: ignore[no-redef, assignment]
    from pydantic.typing import is_union as is_union  # type: ignore[no-redef]

from .datetime_utils import serialize_datetime
from .serialization import convert_and_respect_annotation_metadata
from typing_extensions import TypeAlias

T = TypeVar("T")
Model = TypeVar("Model", bound=pydantic.BaseModel)


def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]:
    """
    Extract the discriminator field name and union variants from a discriminated union type.

    Supports Annotated[Union[...], Field(discriminator=...)] patterns.
    Returns (discriminator, variants) or (None, None) if not a discriminated union.
    """
    origin = typing_extensions.get_origin(type_)
    if origin is typing_extensions.Annotated:
        args = typing_extensions.get_args(type_)
        if len(args) >= 2:
            # args[0] is the wrapped type; the rest are metadata annotations.
            inner_type = args[0]
            # Check annotations for discriminator
            discriminator = None
            for annotation in args[1:]:
                if hasattr(annotation, "discriminator"):
                    discriminator = getattr(annotation, "discriminator", None)
                    break
            if discriminator:
                inner_origin = typing_extensions.get_origin(inner_type)
                if inner_origin is Union:
                    variants = list(typing_extensions.get_args(inner_type))
                    return discriminator, variants
    return None, None


def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]:
    """Get the type annotation of a field from a Pydantic model."""
    if IS_PYDANTIC_V2:
        fields = getattr(model, "model_fields", {})
        field_info = fields.get(field_name)
        if field_info:
            return cast(Optional[Type[Any]], field_info.annotation)
    else:
        fields = getattr(model, "__fields__", {})
        field_info = fields.get(field_name)
        if field_info:
            return cast(Optional[Type[Any]], field_info.outer_type_)
    return None


def _find_variant_by_discriminator(
    variants: List[Type[Any]],
    discriminator: str,
    discriminator_value: Any,
) -> Optional[Type[Any]]:
    """Find the union variant that matches the discriminator value."""
    for variant in variants:
        if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)):
            continue
        disc_annotation = _get_field_annotation(variant, discriminator)
        # Variants mark their discriminator with a Literal["..."] annotation;
        # the first literal argument is the tag value to match against.
        if disc_annotation and is_literal_type(disc_annotation):
            literal_args = get_args(disc_annotation)
            if literal_args and literal_args[0] == discriminator_value:
                return variant
    return None


def _is_string_type(type_: Type[Any]) -> bool:
    """Check if a type is str or Optional[str]."""
    if type_ is str:
        return True
    origin = typing_extensions.get_origin(type_)
    if origin is Union:
        args = typing_extensions.get_args(type_)
        # Optional[str] = Union[str, None]
        non_none_args = [a for a in args if a is not type(None)]
        if len(non_none_args) == 1 and non_none_args[0] is str:
            return True
    return False


def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
    """
    Parse a ServerSentEvent into the appropriate type.

    Handles two scenarios based on where the discriminator field is located:

    1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload.
       The union describes the data content, not the SSE envelope.
       -> Returns: json.loads(data) parsed into the type

       Example: ChatStreamResponse with discriminator='type'
       Input:  ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="")
       Output: ContentDeltaEvent (parsed from data, SSE envelope stripped)

    2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level.
       The union describes the full SSE event structure.
       -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string

       Example: JobStreamResponse with discriminator='event'
       Input:  ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123")
       Output: JobStreamResponse_Error with data as ErrorData object

       But for variants where data is str (like STATUS_UPDATE):
       Input:  ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1")
       Output: JobStreamResponse_StatusUpdate with data as string (not parsed)

    Args:
        sse: The ServerSentEvent object to parse
        type_: The target discriminated union type

    Returns:
        The parsed object of type T

    Note:
        This function is only available in SDK contexts where http_sse module exists.
    """
    sse_event = asdict(sse)
    discriminator, variants = _get_discriminator_and_variants(type_)

    if discriminator is None or variants is None:
        # Not a discriminated union - parse the data field as JSON
        data_value = sse_event.get("data")
        if isinstance(data_value, str) and data_value:
            try:
                parsed_data = json.loads(data_value)
                return parse_obj_as(type_, parsed_data)
            except json.JSONDecodeError as e:
                _logger.warning(
                    "Failed to parse SSE data field as JSON: %s, data: %s",
                    e,
                    data_value[:100] if len(data_value) > 100 else data_value,
                )
        # Fall back to validating the whole SSE envelope dict.
        return parse_obj_as(type_, sse_event)

    data_value = sse_event.get("data")

    # Check if discriminator is at the top level (event-level discrimination)
    if discriminator in sse_event:
        # Case 2: Event-level discrimination
        # Find the matching variant to check if 'data' field needs JSON parsing
        disc_value = sse_event.get(discriminator)
        matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value)

        if matching_variant is not None:
            # Check what type the variant expects for 'data'
            data_type = _get_field_annotation(matching_variant, "data")
            if data_type is not None and not _is_string_type(data_type):
                # Variant expects non-string data - parse JSON
                if isinstance(data_value, str) and data_value:
                    try:
                        parsed_data = json.loads(data_value)
                        new_object = dict(sse_event)
                        new_object["data"] = parsed_data
                        return parse_obj_as(type_, new_object)
                    except json.JSONDecodeError as e:
                        _logger.warning(
                            "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s",
                            e,
                            data_value[:100] if len(data_value) > 100 else data_value,
                        )
        # Either no matching variant, data is string type, or JSON parse failed
        return parse_obj_as(type_, sse_event)
    else:
        # Case 1: Data-level discrimination
        # The discriminator is inside the data payload - extract and parse data only
        if isinstance(data_value, str) and data_value:
            try:
                parsed_data = json.loads(data_value)
                return parse_obj_as(type_, parsed_data)
            except json.JSONDecodeError as e:
                _logger.warning(
                    "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s",
                    e,
                    data_value[:100] if len(data_value) > 100 else data_value,
                )
        return parse_obj_as(type_, sse_event)


def parse_obj_as(type_: Type[T], object_: Any) -> T:
    """Validate *object_* as *type_*, first de-aliasing FieldMetadata-style aliases when needed."""
    # convert_and_respect_annotation_metadata is required for TypedDict aliasing.
    #
    # For Pydantic models, whether we should pre-dealias depends on how the model encodes aliasing:
    # - If the model uses real Pydantic aliases (pydantic.Field(alias=...)), then we must pass wire keys through
    #   unchanged so Pydantic can validate them.
    # - If the model encodes aliasing only via FieldMetadata annotations, then we MUST pre-dealias because Pydantic
    #   will not recognize those aliases during validation.
    if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
        has_pydantic_aliases = False
        if IS_PYDANTIC_V2:
            for field_name, field_info in getattr(type_, "model_fields", {}).items():  # type: ignore[attr-defined]
                alias = getattr(field_info, "alias", None)
                if alias is not None and alias != field_name:
                    has_pydantic_aliases = True
                    break
        else:
            for field in getattr(type_, "__fields__", {}).values():
                alias = getattr(field, "alias", None)
                name = getattr(field, "name", None)
                if alias is not None and name is not None and alias != name:
                    has_pydantic_aliases = True
                    break
        dealiased_object = (
            object_
            if has_pydantic_aliases
            else convert_and_respect_annotation_metadata(object_=object_, annotation=type_, direction="read")
        )
    else:
        dealiased_object = convert_and_respect_annotation_metadata(object_=object_, annotation=type_, direction="read")
    if IS_PYDANTIC_V2:
        adapter = pydantic.TypeAdapter(type_)  # type: ignore[attr-defined]
        return adapter.validate_python(dealiased_object)
    return pydantic.parse_obj_as(type_, dealiased_object)


def to_jsonable_with_fallback(obj: Any, fallback_serializer: Callable[[Any], Any]) -> Any:
    """Convert *obj* to a JSON-able value, deferring to *fallback_serializer* for unknown types."""
    if IS_PYDANTIC_V2:
        from pydantic_core import to_jsonable_python

        return to_jsonable_python(obj, fallback=fallback_serializer)
    # v1 path: hand the object straight to the fallback.
    return fallback_serializer(obj)


class UniversalBaseModel(pydantic.BaseModel):
    """Base model giving identical alias/serialization behavior on pydantic v1 and v2."""

    if IS_PYDANTIC_V2:
        model_config: ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(  # type: ignore[typeddict-unknown-key]
            # Allow fields beginning with `model_` to be used in the model
            protected_namespaces=(),
        )

        @pydantic.model_validator(mode="before")  # type: ignore[attr-defined]
        @classmethod
        def _coerce_field_names_to_aliases(cls, data: Any) -> Any:
            """
            Accept Python field names in input by rewriting them to their Pydantic aliases,
            while avoiding silent collisions when a key could refer to multiple fields.
            """
            if not isinstance(data, Mapping):
                return data

            fields = getattr(cls, "model_fields", {})  # type: ignore[attr-defined]
            name_to_alias: Dict[str, str] = {}
            alias_to_name: Dict[str, str] = {}
            for name, field_info in fields.items():
                alias = getattr(field_info, "alias", None) or name
                name_to_alias[name] = alias
                if alias != name:
                    alias_to_name[alias] = name

            # Detect ambiguous keys: a key that is an alias for one field and a name for another.
            ambiguous_keys = set(alias_to_name.keys()).intersection(set(name_to_alias.keys()))
            for key in ambiguous_keys:
                if key in data and name_to_alias[key] not in data:
                    raise ValueError(
                        f"Ambiguous input key '{key}': it is both a field name and an alias. "
                        "Provide the explicit alias key to disambiguate."
                    )

            original_keys = set(data.keys())
            rewritten: Dict[str, Any] = dict(data)
            for name, alias in name_to_alias.items():
                if alias != name and name in original_keys and alias not in rewritten:
                    rewritten[alias] = rewritten.pop(name)
            return rewritten

        @pydantic.model_serializer(mode="plain", when_used="json")  # type: ignore[attr-defined]
        def serialize_model(self) -> Any:  # type: ignore[name-defined]
            # Serialize via self.dict() so datetimes get the SDK's canonical format.
            serialized = self.dict()  # type: ignore[attr-defined]
            data = {k: serialize_datetime(v) if isinstance(v, dt.datetime) else v for k, v in serialized.items()}
            return data

    else:

        class Config:
            smart_union = True
            json_encoders = {dt.datetime: serialize_datetime}

        @pydantic.root_validator(pre=True)
        def _coerce_field_names_to_aliases(cls, values: Any) -> Any:
            """
            Pydantic v1 equivalent of _coerce_field_names_to_aliases.
            """
            if not isinstance(values, Mapping):
                return values

            fields = getattr(cls, "__fields__", {})
            name_to_alias: Dict[str, str] = {}
            alias_to_name: Dict[str, str] = {}
            for name, field in fields.items():
                alias = getattr(field, "alias", None) or name
                name_to_alias[name] = alias
                if alias != name:
                    alias_to_name[alias] = name

            ambiguous_keys = set(alias_to_name.keys()).intersection(set(name_to_alias.keys()))
            for key in ambiguous_keys:
                if key in values and name_to_alias[key] not in values:
                    raise ValueError(
                        f"Ambiguous input key '{key}': it is both a field name and an alias. "
                        "Provide the explicit alias key to disambiguate."
                    )

            original_keys = set(values.keys())
            rewritten: Dict[str, Any] = dict(values)
            for name, alias in name_to_alias.items():
                if alias != name and name in original_keys and alias not in rewritten:
                    rewritten[alias] = rewritten.pop(name)
            return rewritten

    @classmethod
    def model_construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
        # De-alias FieldMetadata keys before delegating to construct().
        dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
        return cls.construct(_fields_set, **dealiased_object)

    @classmethod
    def construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
        # Validation-free construction, with FieldMetadata de-aliasing applied first.
        dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
        if IS_PYDANTIC_V2:
            return super().model_construct(_fields_set, **dealiased_object)  # type: ignore[misc]
        return super().construct(_fields_set, **dealiased_object)

    def json(self, **kwargs: Any) -> str:
        """JSON-serialize with by_alias/exclude_unset defaults (caller kwargs win)."""
        kwargs_with_defaults = {
            "by_alias": True,
            "exclude_unset": True,
            **kwargs,
        }
        if IS_PYDANTIC_V2:
            return super().model_dump_json(**kwargs_with_defaults)  # type: ignore[misc]
        return super().json(**kwargs_with_defaults)

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Override the default dict method to `exclude_unset` by default. This function patches
        `exclude_unset` to work include fields within non-None default values.
        """
        # Note: the logic here is multiplexed given the levers exposed in Pydantic V1 vs V2
        # Pydantic V1's .dict can be extremely slow, so we do not want to call it twice.
        #
        # We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
        # that we have less control over, and this is less intrusive than custom serializers for now.
        if IS_PYDANTIC_V2:
            kwargs_with_defaults_exclude_unset = {
                **kwargs,
                "by_alias": True,
                "exclude_unset": True,
                "exclude_none": False,
            }
            kwargs_with_defaults_exclude_none = {
                **kwargs,
                "by_alias": True,
                "exclude_none": True,
                "exclude_unset": False,
            }
            # Merge the two dumps so unset-but-defaulted fields survive.
            dict_dump = deep_union_pydantic_dicts(
                super().model_dump(**kwargs_with_defaults_exclude_unset),  # type: ignore[misc]
                super().model_dump(**kwargs_with_defaults_exclude_none),  # type: ignore[misc]
            )

        else:
            _fields_set = self.__fields_set__.copy()

            fields = _get_model_fields(self.__class__)
            for name, field in fields.items():
                if name not in _fields_set:
                    default = _get_field_default(field)

                    # If the default values are non-null act like they've been set
                    # This effectively allows exclude_unset to work like exclude_none where
                    # the latter passes through intentionally set none values.
                    if default is not None or ("exclude_unset" in kwargs and not kwargs["exclude_unset"]):
                        _fields_set.add(name)

                        if default is not None:
                            self.__fields_set__.add(name)

            kwargs_with_defaults_exclude_unset_include_fields = {
                "by_alias": True,
                "exclude_unset": True,
                "include": _fields_set,
                **kwargs,
            }

            dict_dump = super().dict(**kwargs_with_defaults_exclude_unset_include_fields)

        return cast(
            Dict[str, Any],
            convert_and_respect_annotation_metadata(object_=dict_dump, annotation=self.__class__, direction="write"),
        )


def _union_list_of_pydantic_dicts(source: List[Any], destination: List[Any]) -> List[Any]:
    """Element-wise merge of two parallel lists, recursing into dicts and nested lists."""
    converted_list: List[Any] = []
    for i, item in enumerate(source):
        destination_value = destination[i]
        if isinstance(item, dict):
            converted_list.append(deep_union_pydantic_dicts(item, destination_value))
        elif isinstance(item, list):
            converted_list.append(_union_list_of_pydantic_dicts(item, destination_value))
        else:
            converted_list.append(item)
    return converted_list


def deep_union_pydantic_dicts(source: Dict[str, Any], destination: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge *source* into *destination* (source wins); mutates and returns *destination*."""
    for key, value in source.items():
        node = destination.setdefault(key, {})
        if isinstance(value, dict):
            deep_union_pydantic_dicts(value, node)
        # Note: we do not do this same processing for sets given we do not have sets of models
        # and given the sets are unordered, the processing of the set and matching objects would
        # be non-trivial.
        elif isinstance(value, list):
            destination[key] = _union_list_of_pydantic_dicts(value, node)
        else:
            destination[key] = value

    return destination


if IS_PYDANTIC_V2:

    class V2RootModel(UniversalBaseModel, pydantic.RootModel):  # type: ignore[misc, name-defined, type-arg]
        pass

    UniversalRootModel: TypeAlias = V2RootModel  # type: ignore[misc]
else:
    UniversalRootModel: TypeAlias = UniversalBaseModel  # type: ignore[misc, no-redef]


def encode_by_type(o: Any) -> Any:
    """Encode *o* using encoders_by_type, first by exact type then by isinstance; returns None if no encoder matches."""
    # NOTE(review): this index is rebuilt on every call; could be cached at module level — confirm call frequency.
    encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple)
    for type_, encoder in encoders_by_type.items():
        encoders_by_class_tuples[encoder] += (type_,)

    if type(o) in encoders_by_type:
        return encoders_by_type[type(o)](o)
    for encoder, classes_tuple in encoders_by_class_tuples.items():
        if isinstance(o, classes_tuple):
            return encoder(o)


def update_forward_refs(model: Type["Model"], **localns: Any) -> None:
    """Resolve forward references on *model*, using the right API for the pydantic version."""
    if IS_PYDANTIC_V2:
        model.model_rebuild(raise_errors=False)  # type: ignore[attr-defined]
    else:
        model.update_forward_refs(**localns)


# Mirrors Pydantic's internal typing
AnyCallable = Callable[..., Any]


def universal_root_validator(
    pre: bool = False,
) -> Callable[[AnyCallable], AnyCallable]:
    """Decorator factory: a root validator that works on both pydantic versions."""

    def decorator(func: AnyCallable) -> AnyCallable:
        if IS_PYDANTIC_V2:
            # In Pydantic v2, for RootModel we always use "before" mode
            # The custom validators transform the input value before the model is created
            return cast(AnyCallable, pydantic.model_validator(mode="before")(func))  # type: ignore[attr-defined]
        return cast(AnyCallable, pydantic.root_validator(pre=pre)(func))  # type: ignore[call-overload]

    return decorator


def universal_field_validator(field_name: str, pre: bool = False) -> Callable[[AnyCallable], AnyCallable]:
    """Decorator factory: a single-field validator that works on both pydantic versions."""

    def decorator(func: AnyCallable) -> AnyCallable:
        if IS_PYDANTIC_V2:
            return cast(AnyCallable, pydantic.field_validator(field_name, mode="before" if pre else "after")(func))  # type: ignore[attr-defined]
        return cast(AnyCallable, pydantic.validator(field_name, pre=pre)(func))

    return decorator


# Field descriptor type that is valid on either pydantic version.
PydanticField = Union[ModelField, _FieldInfo]


def _get_model_fields(model: Type["Model"]) -> Mapping[str, PydanticField]:
    """Return the model's field mapping via the version-appropriate attribute."""
    if IS_PYDANTIC_V2:
        return cast(Mapping[str, PydanticField], model.model_fields)  # type: ignore[attr-defined]
    return cast(Mapping[str, PydanticField], model.__fields__)


def _get_field_default(field: PydanticField) -> Any:
    """Return a field's default value, normalizing pydantic v2's PydanticUndefined to None."""
    try:
        value = field.get_default()  # type: ignore[union-attr]
    # NOTE(review): bare except silently swallows all errors from get_default();
    # an `except Exception:` would be safer — confirm intent before changing.
    except:
        value = field.default

    if IS_PYDANTIC_V2:
        from pydantic_core import PydanticUndefined

        # NOTE(review): identity comparison (`is PydanticUndefined`) would be the
        # conventional check for this sentinel — confirm `==` is intentional.
        if value == PydanticUndefined:
            return None
        return value
    return value



================================================
FILE: src/cohere/core/query_encoder.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import Any, Dict, List, Optional, Tuple import pydantic # Flattens dicts to be of the form {"key[subkey][subkey2]": value} where value is not a dict def traverse_query_dict(dict_flat: Dict[str, Any], key_prefix: Optional[str] = None) -> List[Tuple[str, Any]]: result = [] for k, v in dict_flat.items(): key = f"{key_prefix}[{k}]" if key_prefix is not None else k if isinstance(v, dict): result.extend(traverse_query_dict(v, key)) elif isinstance(v, list): for arr_v in v: if isinstance(arr_v, dict): result.extend(traverse_query_dict(arr_v, key)) else: result.append((key, arr_v)) else: result.append((key, v)) return result def single_query_encoder(query_key: str, query_value: Any) -> List[Tuple[str, Any]]: if isinstance(query_value, pydantic.BaseModel) or isinstance(query_value, dict): if isinstance(query_value, pydantic.BaseModel): obj_dict = query_value.dict(by_alias=True) else: obj_dict = query_value return traverse_query_dict(obj_dict, query_key) elif isinstance(query_value, list): encoded_values: List[Tuple[str, Any]] = [] for value in query_value: if isinstance(value, pydantic.BaseModel) or isinstance(value, dict): if isinstance(value, pydantic.BaseModel): obj_dict = value.dict(by_alias=True) elif isinstance(value, dict): obj_dict = value encoded_values.extend(single_query_encoder(query_key, obj_dict)) else: encoded_values.append((query_key, value)) return encoded_values return [(query_key, query_value)] def encode_query(query: Optional[Dict[str, Any]]) -> Optional[List[Tuple[str, Any]]]: if query is None: return None encoded_query = [] for k, v in query.items(): encoded_query.extend(single_query_encoder(k, v)) return encoded_query ================================================ FILE: src/cohere/core/remove_none_from_dict.py ================================================ # This file was auto-generated by Fern from our API Definition. 
from typing import Any, Dict, Mapping, Optional def remove_none_from_dict(original: Mapping[str, Optional[Any]]) -> Dict[str, Any]: new: Dict[str, Any] = {} for key, value in original.items(): if value is not None: new[key] = value return new ================================================ FILE: src/cohere/core/request_options.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing try: from typing import NotRequired # type: ignore except ImportError: from typing_extensions import NotRequired class RequestOptions(typing.TypedDict, total=False): """ Additional options for request-specific configuration when calling APIs via the SDK. This is used primarily as an optional final parameter for service functions. Attributes: - timeout_in_seconds: int. The number of seconds to await an API call before timing out. - max_retries: int. The max number of retries to attempt if the API call fails. - additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict - additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict - additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict - chunk_size: int. The size, in bytes, to process each chunk of data being streamed back within the response. This equates to leveraging `chunk_size` within `requests` or `httpx`, and is only leveraged for file downloads. 
""" timeout_in_seconds: NotRequired[int] max_retries: NotRequired[int] additional_headers: NotRequired[typing.Dict[str, typing.Any]] additional_query_parameters: NotRequired[typing.Dict[str, typing.Any]] additional_body_parameters: NotRequired[typing.Dict[str, typing.Any]] chunk_size: NotRequired[int] ================================================ FILE: src/cohere/core/serialization.py ================================================ # This file was auto-generated by Fern from our API Definition. import collections import inspect import typing import pydantic import typing_extensions class FieldMetadata: """ Metadata class used to annotate fields to provide additional information. Example: class MyDict(TypedDict): field: typing.Annotated[str, FieldMetadata(alias="field_name")] Will serialize: `{"field": "value"}` To: `{"field_name": "value"}` """ alias: str def __init__(self, *, alias: str) -> None: self.alias = alias def convert_and_respect_annotation_metadata( *, object_: typing.Any, annotation: typing.Any, inner_type: typing.Optional[typing.Any] = None, direction: typing.Literal["read", "write"], ) -> typing.Any: """ Respect the metadata annotations on a field, such as aliasing. This function effectively manipulates the dict-form of an object to respect the metadata annotations. This is primarily used for TypedDicts, which cannot support aliasing out of the box, and can be extended for additional utilities, such as defaults. 
Parameters ---------- object_ : typing.Any annotation : type The type we're looking to apply typing annotations from inner_type : typing.Optional[type] Returns ------- typing.Any """ if object_ is None: return None if inner_type is None: inner_type = annotation clean_type = _remove_annotations(inner_type) # Pydantic models if ( inspect.isclass(clean_type) and issubclass(clean_type, pydantic.BaseModel) and isinstance(object_, typing.Mapping) ): return _convert_mapping(object_, clean_type, direction) # TypedDicts if typing_extensions.is_typeddict(clean_type) and isinstance(object_, typing.Mapping): return _convert_mapping(object_, clean_type, direction) if ( typing_extensions.get_origin(clean_type) == typing.Dict or typing_extensions.get_origin(clean_type) == dict or clean_type == typing.Dict ) and isinstance(object_, typing.Dict): key_type = typing_extensions.get_args(clean_type)[0] value_type = typing_extensions.get_args(clean_type)[1] return { key: convert_and_respect_annotation_metadata( object_=value, annotation=annotation, inner_type=value_type, direction=direction, ) for key, value in object_.items() } # If you're iterating on a string, do not bother to coerce it to a sequence. 
if not isinstance(object_, str): if ( typing_extensions.get_origin(clean_type) == typing.Set or typing_extensions.get_origin(clean_type) == set or clean_type == typing.Set ) and isinstance(object_, typing.Set): inner_type = typing_extensions.get_args(clean_type)[0] return { convert_and_respect_annotation_metadata( object_=item, annotation=annotation, inner_type=inner_type, direction=direction, ) for item in object_ } elif ( ( typing_extensions.get_origin(clean_type) == typing.List or typing_extensions.get_origin(clean_type) == list or clean_type == typing.List ) and isinstance(object_, typing.List) ) or ( ( typing_extensions.get_origin(clean_type) == typing.Sequence or typing_extensions.get_origin(clean_type) == collections.abc.Sequence or clean_type == typing.Sequence ) and isinstance(object_, typing.Sequence) ): inner_type = typing_extensions.get_args(clean_type)[0] return [ convert_and_respect_annotation_metadata( object_=item, annotation=annotation, inner_type=inner_type, direction=direction, ) for item in object_ ] if typing_extensions.get_origin(clean_type) == typing.Union: # We should be able to ~relatively~ safely try to convert keys against all # member types in the union, the edge case here is if one member aliases a field # of the same name to a different name from another member # Or if another member aliases a field of the same name that another member does not. for member in typing_extensions.get_args(clean_type): object_ = convert_and_respect_annotation_metadata( object_=object_, annotation=annotation, inner_type=member, direction=direction, ) return object_ annotated_type = _get_annotation(annotation) if annotated_type is None: return object_ # If the object is not a TypedDict, a Union, or other container (list, set, sequence, etc.) # Then we can safely call it on the recursive conversion. 
return object_ def _convert_mapping( object_: typing.Mapping[str, object], expected_type: typing.Any, direction: typing.Literal["read", "write"], ) -> typing.Mapping[str, object]: converted_object: typing.Dict[str, object] = {} try: annotations = typing_extensions.get_type_hints(expected_type, include_extras=True) except NameError: # The TypedDict contains a circular reference, so # we use the __annotations__ attribute directly. annotations = getattr(expected_type, "__annotations__", {}) aliases_to_field_names = _get_alias_to_field_name(annotations) for key, value in object_.items(): if direction == "read" and key in aliases_to_field_names: dealiased_key = aliases_to_field_names.get(key) if dealiased_key is not None: type_ = annotations.get(dealiased_key) else: type_ = annotations.get(key) # Note you can't get the annotation by the field name if you're in read mode, so you must check the aliases map # # So this is effectively saying if we're in write mode, and we don't have a type, or if we're in read mode and we don't have an alias # then we can just pass the value through as is if type_ is None: converted_object[key] = value elif direction == "read" and key not in aliases_to_field_names: converted_object[key] = convert_and_respect_annotation_metadata( object_=value, annotation=type_, direction=direction ) else: converted_object[_alias_key(key, type_, direction, aliases_to_field_names)] = ( convert_and_respect_annotation_metadata(object_=value, annotation=type_, direction=direction) ) return converted_object def _get_annotation(type_: typing.Any) -> typing.Optional[typing.Any]: maybe_annotated_type = typing_extensions.get_origin(type_) if maybe_annotated_type is None: return None if maybe_annotated_type == typing_extensions.NotRequired: type_ = typing_extensions.get_args(type_)[0] maybe_annotated_type = typing_extensions.get_origin(type_) if maybe_annotated_type == typing_extensions.Annotated: return type_ return None def _remove_annotations(type_: typing.Any) -> 
typing.Any: maybe_annotated_type = typing_extensions.get_origin(type_) if maybe_annotated_type is None: return type_ if maybe_annotated_type == typing_extensions.NotRequired: return _remove_annotations(typing_extensions.get_args(type_)[0]) if maybe_annotated_type == typing_extensions.Annotated: return _remove_annotations(typing_extensions.get_args(type_)[0]) return type_ def get_alias_to_field_mapping(type_: typing.Any) -> typing.Dict[str, str]: annotations = typing_extensions.get_type_hints(type_, include_extras=True) return _get_alias_to_field_name(annotations) def get_field_to_alias_mapping(type_: typing.Any) -> typing.Dict[str, str]: annotations = typing_extensions.get_type_hints(type_, include_extras=True) return _get_field_to_alias_name(annotations) def _get_alias_to_field_name( field_to_hint: typing.Dict[str, typing.Any], ) -> typing.Dict[str, str]: aliases = {} for field, hint in field_to_hint.items(): maybe_alias = _get_alias_from_type(hint) if maybe_alias is not None: aliases[maybe_alias] = field return aliases def _get_field_to_alias_name( field_to_hint: typing.Dict[str, typing.Any], ) -> typing.Dict[str, str]: aliases = {} for field, hint in field_to_hint.items(): maybe_alias = _get_alias_from_type(hint) if maybe_alias is not None: aliases[field] = maybe_alias return aliases def _get_alias_from_type(type_: typing.Any) -> typing.Optional[str]: maybe_annotated_type = _get_annotation(type_) if maybe_annotated_type is not None: # The actual annotations are 1 onward, the first is the annotated type annotations = typing_extensions.get_args(maybe_annotated_type)[1:] for annotation in annotations: if isinstance(annotation, FieldMetadata) and annotation.alias is not None: return annotation.alias return None def _alias_key( key: str, type_: typing.Any, direction: typing.Literal["read", "write"], aliases_to_field_names: typing.Dict[str, str], ) -> str: if direction == "read": return aliases_to_field_names.get(key, key) return _get_alias_from_type(type_=type_) or 
key ================================================ FILE: src/cohere/core/unchecked_base_model.py ================================================ # This file was auto-generated by Fern from our API Definition. import datetime as dt import enum import inspect import sys import typing import uuid import pydantic import typing_extensions from .pydantic_utilities import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, ModelField, UniversalBaseModel, get_args, get_origin, is_literal_type, is_union, parse_date, parse_datetime, parse_obj_as, ) from .serialization import get_field_to_alias_mapping from pydantic_core import PydanticUndefined class UnionMetadata: discriminant: str def __init__(self, *, discriminant: str) -> None: self.discriminant = discriminant Model = typing.TypeVar("Model", bound=pydantic.BaseModel) def _maybe_resolve_forward_ref( type_: typing.Any, host: typing.Optional[typing.Type[typing.Any]], ) -> typing.Any: """Resolve a ForwardRef using the module where *host* is defined. Pydantic v2 + ``from __future__ import annotations`` can leave field annotations as ``list[ForwardRef('Block')]`` even after ``model_rebuild``. Without resolution, ``construct_type`` sees a ForwardRef (not a class) and skips recursive model construction, leaving nested data as raw dicts. """ if host is None or not isinstance(type_, typing.ForwardRef): return type_ mod = sys.modules.get(host.__module__) if mod is None: return type_ try: return eval(type_.__forward_arg__, vars(mod)) except Exception: return type_ class UncheckedBaseModel(UniversalBaseModel): if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: extra = pydantic.Extra.allow @classmethod def model_construct( cls: typing.Type["Model"], _fields_set: typing.Optional[typing.Set[str]] = None, **values: typing.Any, ) -> "Model": # Fallback construct function to the specified override below. 
class UncheckedBaseModel(UniversalBaseModel):
    # Accept and retain unknown keys from the API instead of failing validation,
    # so newer server responses don't break older SDK versions.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            extra = pydantic.Extra.allow

    @classmethod
    def model_construct(
        cls: typing.Type["Model"],
        _fields_set: typing.Optional[typing.Set[str]] = None,
        **values: typing.Any,
    ) -> "Model":
        # Fallback construct function to the specified override below.
        return cls.construct(_fields_set=_fields_set, **values)

    # Allow construct to not validate model
    # Implementation taken from: https://github.com/pydantic/pydantic/issues/1168#issuecomment-817742836
    @classmethod
    def construct(
        cls: typing.Type["Model"],
        _fields_set: typing.Optional[typing.Set[str]] = None,
        **values: typing.Any,
    ) -> "Model":
        """Build an instance without validation, recursively coercing nested values.

        Looks up each declared field in *values* under its pydantic alias, its
        serialization alias, or its name, coerces the value via
        ``construct_type``, applies declared defaults for missing keys, and
        stashes any unrecognized keys as extras.
        """
        m = cls.__new__(cls)
        fields_values = {}
        if _fields_set is None:
            _fields_set = set(values.keys())
        fields = _get_model_fields(cls)
        populate_by_name = _get_is_populate_by_name(cls)
        field_aliases = get_field_to_alias_mapping(cls)

        for name, field in fields.items():
            # Key here is only used to pull data from the values dict
            # you should always use the NAME of the field to for field_values, etc.
            # because that's how the object is constructed from a pydantic perspective
            key = field.alias
            # Prefer the FieldMetadata serialization alias when pydantic has no
            # distinct alias of its own.
            if (key is None or field.alias == name) and name in field_aliases:
                key = field_aliases[name]
            # Added this to allow population by field name
            if key is None or (key not in values and populate_by_name):
                key = name

            if key in values:
                if IS_PYDANTIC_V2:
                    type_ = field.annotation  # type: ignore # Pydantic v2
                else:
                    type_ = typing.cast(typing.Type, field.outer_type_)  # type: ignore # Pydantic < v1.10.15
                fields_values[name] = (
                    construct_type(object_=values[key], type_=type_, host=cls) if type_ is not None else values[key]
                )
                _fields_set.add(name)
            else:
                default = _get_field_default(field)
                fields_values[name] = default

                # If the default values are non-null act like they've been set
                # This effectively allows exclude_unset to work like exclude_none where
                # the latter passes through intentionally set none values.
                if default != None and default != PydanticUndefined:
                    _fields_set.add(name)

        # Add extras back in
        extras = {}
        pydantic_alias_fields = [field.alias for field in fields.values()]
        internal_alias_fields = list(field_aliases.values())
        for key, value in values.items():
            # If the key is not a field by name, nor an alias to a field, then it's extra
            if (key not in pydantic_alias_fields and key not in internal_alias_fields) and key not in fields:
                if IS_PYDANTIC_V2:
                    extras[key] = value
                else:
                    # Pydantic v1 stores extras directly in __dict__.
                    _fields_set.add(key)
                    fields_values[key] = value

        # Bypass __init__/validation entirely by writing the instance state
        # directly, matching what pydantic's own construct() does internally.
        object.__setattr__(m, "__dict__", fields_values)

        if IS_PYDANTIC_V2:
            # These private attributes are required by pydantic v2 instances.
            object.__setattr__(m, "__pydantic_private__", None)
            object.__setattr__(m, "__pydantic_extra__", extras)
            object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
        else:
            object.__setattr__(m, "__fields_set__", _fields_set)
            m._init_private_attributes()  # type: ignore # Pydantic v1

        return m
def _validate_collection_items_compatible(collection: typing.Any, target_type: typing.Type[typing.Any]) -> bool:
    """Check whether every item of *collection* can be treated as *target_type*.

    Only Pydantic model targets are actually inspected: dict items must parse
    cleanly into the model and non-dict items must already be instances of it.
    For any non-model target the collection is accepted unconditionally.

    Args:
        collection: The collection to validate (list, set, or dict values)
        target_type: The target type to validate against

    Returns:
        True if all items are compatible, False otherwise
    """
    if not (inspect.isclass(target_type) and issubclass(target_type, pydantic.BaseModel)):
        return True
    for item in collection:
        try:
            if isinstance(item, dict):
                # A dict item must round-trip through validation.
                parse_obj_as(target_type, item)
            elif not isinstance(item, target_type):
                return False
        except Exception:
            return False
    return True


def _get_literal_field_value(
    inner_type: typing.Type[typing.Any], field_name: str, field: typing.Any, object_: typing.Any
) -> typing.Any:
    """Fetch the value of a Literal field from *object_*.

    Tries the field's wire alias first (from FieldMetadata), then pydantic's
    own alias, for both dict inputs and attribute-style objects.
    """
    wire_key = get_field_to_alias_mapping(inner_type).get(field_name, field_name)
    alias = getattr(field, "alias", None)
    if isinstance(object_, dict):
        if wire_key in object_:
            return object_[wire_key]
        if alias and alias != wire_key and alias in object_:
            return object_[alias]
        return None
    # Attribute access path: fall back to the pydantic alias when one exists.
    fallback = getattr(object_, alias, None) if alias else None
    return getattr(object_, wire_key, fallback)
""" fields = _get_model_fields(inner_type) for field_name, field in fields.items(): if IS_PYDANTIC_V2: field_type = field.annotation # type: ignore # Pydantic v2 else: field_type = field.outer_type_ # type: ignore # Pydantic v1 if is_literal_type(field_type): # type: ignore[arg-type] field_default = _get_field_default(field) object_value = _get_literal_field_value(inner_type, field_name, field, object_) if field_default != object_value: return False return True def _convert_undiscriminated_union_type( union_type: typing.Type[typing.Any], object_: typing.Any, host: typing.Optional[typing.Type[typing.Any]] = None, ) -> typing.Any: inner_types = get_args(union_type) if typing.Any in inner_types: return object_ # When any union member carries a Literal discriminant field, require the # discriminant key to be present AND matching before accepting a candidate. # This prevents models with all-optional fields (e.g. FigureDetails) from # greedily matching inputs that belong to a different variant or to a # plain-dict fallback (e.g. EmptyBlockDetails = Dict[str, Any]). 
def _convert_undiscriminated_union_type(
    union_type: typing.Type[typing.Any],
    object_: typing.Any,
    host: typing.Optional[typing.Type[typing.Any]] = None,
) -> typing.Any:
    """Coerce *object_* into one of the members of an undiscriminated union.

    Tries, in order: validated parses (strictest), unvalidated construction of
    members whose Literal fields match, then any member that constructs at all.
    Returns None implicitly when every attempt fails.
    """
    inner_types = get_args(union_type)
    if typing.Any in inner_types:
        return object_

    # When any union member carries a Literal discriminant field, require the
    # discriminant key to be present AND matching before accepting a candidate.
    # This prevents models with all-optional fields (e.g. FigureDetails) from
    # greedily matching inputs that belong to a different variant or to a
    # plain-dict fallback (e.g. EmptyBlockDetails = Dict[str, Any]).
    has_literal_discriminant = any(
        inspect.isclass(t)
        and issubclass(t, pydantic.BaseModel)
        and any(
            is_literal_type(
                f.annotation if IS_PYDANTIC_V2 else f.outer_type_  # type: ignore
            )
            for f in _get_model_fields(t).values()
        )
        for t in inner_types
    )

    for inner_type in inner_types:
        # Handle lists of objects that need parsing
        if get_origin(inner_type) is list and isinstance(object_, list):
            list_inner_type = _maybe_resolve_forward_ref(get_args(inner_type)[0], host)
            try:
                if inspect.isclass(list_inner_type) and issubclass(list_inner_type, pydantic.BaseModel):
                    # Validate that all items in the list are compatible with the target type
                    if _validate_collection_items_compatible(object_, list_inner_type):
                        parsed_list = [parse_obj_as(object_=item, type_=list_inner_type) for item in object_]
                        return parsed_list
            except Exception:
                pass

        try:
            if inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel):
                if has_literal_discriminant and not _literal_fields_match_strict(inner_type, object_):
                    continue
                # Attempt a validated parse until one works
                return parse_obj_as(inner_type, object_)
        except Exception:
            continue

    # First pass: try types where all literal fields match the object's values.
    for inner_type in inner_types:
        if inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel):
            if has_literal_discriminant:
                if not _literal_fields_match_strict(inner_type, object_):
                    continue
            else:
                # Legacy lenient check: skip only when a Literal value is
                # present but doesn't match (allows absent-discriminant inputs).
                fields = _get_model_fields(inner_type)
                literal_fields_match = True
                for field_name, field in fields.items():
                    if IS_PYDANTIC_V2:
                        field_type = field.annotation  # type: ignore # Pydantic v2
                    else:
                        field_type = field.outer_type_  # type: ignore # Pydantic v1
                    if is_literal_type(field_type):  # type: ignore[arg-type]
                        field_default = _get_field_default(field)
                        object_value = _get_literal_field_value(inner_type, field_name, field, object_)
                        if object_value is not None and field_default != object_value:
                            literal_fields_match = False
                            break
                if not literal_fields_match:
                    continue
        try:
            # Unvalidated construction — laxer than the parse attempts above.
            return construct_type(object_=object_, type_=inner_type, host=host)
        except Exception:
            continue

    # Second pass: if no literal matches, return the first successful cast.
    # When a Literal discriminant is present, skip Pydantic models whose
    # discriminant doesn't match so that plain-dict fallback types are reached.
    for inner_type in inner_types:
        try:
            if has_literal_discriminant and inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel):
                if not _literal_fields_match_strict(inner_type, object_):
                    continue
            return construct_type(object_=object_, type_=inner_type, host=host)
        except Exception:
            continue
    # All passes exhausted: falls through and returns None implicitly.
def _convert_union_type(
    type_: typing.Type[typing.Any],
    object_: typing.Any,
    host: typing.Optional[typing.Type[typing.Any]] = None,
) -> typing.Any:
    """Coerce *object_* into a member of the union *type_*.

    When the union is ``Annotated[..., UnionMetadata(discriminant=...)]``, the
    discriminant value on *object_* is matched against each member's declared
    default for that field; otherwise (or on any failure) falls back to
    undiscriminated resolution.
    """
    base_type = get_origin(type_) or type_
    union_type = type_
    if base_type == typing_extensions.Annotated:  # type: ignore[comparison-overlap]
        union_type = get_args(type_)[0]
        annotated_metadata = get_args(type_)[1:]
        for metadata in annotated_metadata:
            if isinstance(metadata, UnionMetadata):
                try:
                    # Cast to the correct type, based on the discriminant
                    for inner_type in get_args(union_type):
                        try:
                            objects_discriminant = getattr(object_, metadata.discriminant)
                        # Narrowed from a bare `except:` which also swallowed
                        # KeyboardInterrupt/SystemExit; attribute lookup failed,
                        # so fall back to mapping-style access.
                        except Exception:
                            objects_discriminant = object_[metadata.discriminant]
                        if inner_type.__fields__[metadata.discriminant].default == objects_discriminant:
                            return construct_type(object_=object_, type_=inner_type, host=host)
                except Exception:
                    # Allow to fall through to our regular union handling
                    pass
    return _convert_undiscriminated_union_type(union_type, object_, host)


def construct_type(
    *,
    type_: typing.Type[typing.Any],
    object_: typing.Any,
    host: typing.Optional[typing.Type[typing.Any]] = None,
) -> typing.Any:
    """
    Here we are essentially creating the same `construct` method in spirit as the above, but for all
    types, not just Pydantic models.
    The idea is to essentially attempt to coerce object_ to type_ (recursively)
    """
    # Short circuit when dealing with optionals, don't try to coerces None to a type
    if object_ is None:
        return None

    base_type = get_origin(type_) or type_
    is_annotated = base_type == typing_extensions.Annotated  # type: ignore[comparison-overlap]
    maybe_annotation_members = get_args(type_)
    is_annotated_union = is_annotated and is_union(get_origin(maybe_annotation_members[0]))

    if base_type == typing.Any:  # type: ignore[comparison-overlap]
        return object_

    if base_type == dict:
        if not isinstance(object_, typing.Mapping):
            return object_
        key_type, items_type = get_args(type_)
        # Forward refs must be resolved before the recursive calls below can
        # recognize nested model types.
        key_type = _maybe_resolve_forward_ref(key_type, host)
        items_type = _maybe_resolve_forward_ref(items_type, host)
        d = {
            construct_type(object_=key, type_=key_type, host=host): construct_type(
                object_=item, type_=items_type, host=host
            )
            for key, item in object_.items()
        }
        return d

    if base_type == list:
        if not isinstance(object_, list):
            return object_
        inner_type = _maybe_resolve_forward_ref(get_args(type_)[0], host)
        return [construct_type(object_=entry, type_=inner_type, host=host) for entry in object_]

    if base_type == set:
        if not isinstance(object_, set) and not isinstance(object_, list):
            return object_
        inner_type = _maybe_resolve_forward_ref(get_args(type_)[0], host)
        return {construct_type(object_=entry, type_=inner_type, host=host) for entry in object_}

    if is_union(base_type) or is_annotated_union:
        return _convert_union_type(type_, object_, host)

    # Cannot do an `issubclass` with a literal type, let's also just confirm we have a class before this call
    if (
        object_ is not None
        and not is_literal_type(type_)
        and (
            (inspect.isclass(base_type) and issubclass(base_type, pydantic.BaseModel))
            or (
                is_annotated
                and inspect.isclass(maybe_annotation_members[0])
                and issubclass(maybe_annotation_members[0], pydantic.BaseModel)
            )
        )
    ):
        if IS_PYDANTIC_V2:
            return type_.model_construct(**object_)
        else:
            return type_.construct(**object_)

    # Scalar coercions below are all best-effort: on any failure the raw
    # value is returned unchanged rather than raising.
    if base_type == dt.datetime:
        try:
            return parse_datetime(object_)
        except Exception:
            return object_

    if base_type == dt.date:
        try:
            return parse_date(object_)
        except Exception:
            return object_

    if base_type == uuid.UUID:
        try:
            return uuid.UUID(object_)
        except Exception:
            return object_

    if base_type == int:
        try:
            return int(object_)
        except Exception:
            return object_

    if base_type == bool:
        try:
            if isinstance(object_, str):
                # API booleans may arrive as "true"/"1" strings.
                stringified_object = object_.lower()
                return stringified_object == "true" or stringified_object == "1"
            return bool(object_)
        except Exception:
            return object_

    if inspect.isclass(base_type) and issubclass(base_type, enum.Enum):
        try:
            return base_type(object_)
        except (ValueError, KeyError):
            return object_

    return object_
def _get_is_populate_by_name(model: typing.Type["Model"]) -> bool:
    """Return whether *model* permits populating fields by their Python name
    (as opposed to alias only), across Pydantic v1/v2."""
    if IS_PYDANTIC_V2:
        return model.model_config.get("populate_by_name", False)  # type: ignore # Pydantic v2
    return model.__config__.allow_population_by_field_name  # type: ignore # Pydantic v1


from pydantic.fields import FieldInfo as _FieldInfo

# Pydantic V1 swapped the typing of __fields__'s values from ModelField to FieldInfo
# And so we try to handle both V1 cases, as well as V2 (FieldInfo from model.model_fields)
PydanticField = typing.Union[ModelField, _FieldInfo]


def _get_model_fields(
    model: typing.Type["Model"],
) -> typing.Mapping[str, PydanticField]:
    """Return the model's field map, abstracting over Pydantic v1/v2 layouts."""
    if IS_PYDANTIC_V2:
        return model.model_fields  # type: ignore # Pydantic v2
    else:
        return model.__fields__  # type: ignore # Pydantic v1


def _get_field_default(field: PydanticField) -> typing.Any:
    """Return the declared default value for *field*, or None when undefined.

    Pydantic v2 reports "no default" via the ``PydanticUndefined`` sentinel,
    which is normalized to ``None`` here.
    """
    try:
        value = field.get_default()  # type: ignore # Pydantic < v1.10.15
    # Narrowed from a bare `except:`; very old field objects lack get_default()
    # and expose the default attribute directly.
    except Exception:
        value = field.default
    if IS_PYDANTIC_V2:
        from pydantic_core import PydanticUndefined

        # PydanticUndefined is a singleton sentinel: compare by identity so we
        # never invoke an arbitrary __eq__ on the default value itself.
        if value is PydanticUndefined:
            return None
        return value
    return value
# This file was auto-generated by Fern from our API Definition.

# isort: skip_file

import typing
from importlib import import_module

if typing.TYPE_CHECKING:
    from .types import DatasetsCreateResponse, DatasetsGetResponse, DatasetsGetUsageResponse, DatasetsListResponse

# Maps each lazily-importable attribute to the module (relative to this
# package) that actually defines it.
_dynamic_imports: typing.Dict[str, str] = {
    "DatasetsCreateResponse": ".types",
    "DatasetsGetResponse": ".types",
    "DatasetsGetUsageResponse": ".types",
    "DatasetsListResponse": ".types",
}


def __getattr__(attr_name: str) -> typing.Any:
    """Import *attr_name* lazily on first access (PEP 562 module __getattr__)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # A mapping of "Name" -> ".Name" means the attribute *is* that module.
        return module if module_name == f".{attr_name}" else getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e


def __dir__():
    """Expose the lazy attributes to dir()/autocomplete."""
    return sorted(_dynamic_imports)


__all__ = ["DatasetsCreateResponse", "DatasetsGetResponse", "DatasetsGetUsageResponse", "DatasetsListResponse"]
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)


class DatasetsClient:
    """Synchronous high-level Datasets client.

    Thin wrapper over ``RawDatasetsClient`` that unwraps each HTTP response
    and returns its parsed ``.data`` payload.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawDatasetsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawDatasetsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawDatasetsClient
        """
        return self._raw_client

    def list(
        self,
        *,
        dataset_type: typing.Optional[str] = None,
        before: typing.Optional[dt.datetime] = None,
        after: typing.Optional[dt.datetime] = None,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        validation_status: typing.Optional[DatasetValidationStatus] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsListResponse:
        """
        List datasets that have been created.

        Parameters
        ----------
        dataset_type : typing.Optional[str]
            optional filter by dataset type
        before : typing.Optional[dt.datetime]
            optional filter before a date
        after : typing.Optional[dt.datetime]
            optional filter after a date
        limit : typing.Optional[float]
            optional limit to number of results
        offset : typing.Optional[float]
            optional offset to start of results
        validation_status : typing.Optional[DatasetValidationStatus]
            optional filter by validation status
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsListResponse
            A successful response.
        """
        return self._raw_client.list(
            dataset_type=dataset_type,
            before=before,
            after=after,
            limit=limit,
            offset=offset,
            validation_status=validation_status,
            request_options=request_options,
        ).data

    def create(
        self,
        *,
        name: str,
        type: DatasetType,
        data: core.File,
        keep_original_file: typing.Optional[bool] = None,
        skip_malformed_input: typing.Optional[bool] = None,
        keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        text_separator: typing.Optional[str] = None,
        csv_delimiter: typing.Optional[str] = None,
        eval_data: typing.Optional[core.File] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsCreateResponse:
        """
        Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information.

        Parameters
        ----------
        name : str
            The name of the uploaded dataset.
        type : DatasetType
            The dataset type, used to validate the data; the only valid type is `embed-input`, used with the Embed Jobs API.
        data : core.File
            The file to upload. See core.File for more documentation.
        keep_original_file : typing.Optional[bool]
            Indicates if the original file should be stored.
        skip_malformed_input : typing.Optional[bool]
            Drop rows with malformed input instead of failing validation; dropped rows are returned in the warnings field.
        keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            Names of fields to persist in the Dataset beyond the required schema fields; validation fails if any are missing from the uploaded file.
        optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            Names of fields to persist in the Dataset beyond the required schema fields; validation still passes if some are missing from the uploaded file.
        text_separator : typing.Optional[str]
            Raw .txt uploads are split into entries using this separator.
        csv_delimiter : typing.Optional[str]
            The delimiter used for .csv uploads.
        eval_data : typing.Optional[core.File]
            Optional evaluation data file. See core.File for more documentation.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsCreateResponse
            A successful response.
        """
        return self._raw_client.create(
            name=name,
            type=type,
            data=data,
            keep_original_file=keep_original_file,
            skip_malformed_input=skip_malformed_input,
            keep_fields=keep_fields,
            optional_fields=optional_fields,
            text_separator=text_separator,
            csv_delimiter=csv_delimiter,
            eval_data=eval_data,
            request_options=request_options,
        ).data

    def get_usage(self, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetUsageResponse:
        """
        View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetUsageResponse
            A successful response.
        """
        return self._raw_client.get_usage(request_options=request_options).data

    def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetResponse:
        """
        Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.

        Parameters
        ----------
        id : str
            The dataset ID.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetResponse
            A successful response.
        """
        return self._raw_client.get(id, request_options=request_options).data

    def delete(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> typing.Dict[str, typing.Any]:
        """
        Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually.

        Parameters
        ----------
        id : str
            The dataset ID.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        typing.Dict[str, typing.Any]
            A successful response.
        """
        return self._raw_client.delete(id, request_options=request_options).data
class AsyncDatasetsClient:
    """Asynchronous high-level Datasets client.

    Thin wrapper over ``AsyncRawDatasetsClient`` that unwraps each HTTP
    response and returns its parsed ``.data`` payload.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawDatasetsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawDatasetsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawDatasetsClient
        """
        return self._raw_client

    async def list(
        self,
        *,
        dataset_type: typing.Optional[str] = None,
        before: typing.Optional[dt.datetime] = None,
        after: typing.Optional[dt.datetime] = None,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        validation_status: typing.Optional[DatasetValidationStatus] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsListResponse:
        """
        List datasets that have been created.

        Parameters
        ----------
        dataset_type : typing.Optional[str]
            optional filter by dataset type
        before : typing.Optional[dt.datetime]
            optional filter before a date
        after : typing.Optional[dt.datetime]
            optional filter after a date
        limit : typing.Optional[float]
            optional limit to number of results
        offset : typing.Optional[float]
            optional offset to start of results
        validation_status : typing.Optional[DatasetValidationStatus]
            optional filter by validation status
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsListResponse
            A successful response.
        """
        response = await self._raw_client.list(
            dataset_type=dataset_type,
            before=before,
            after=after,
            limit=limit,
            offset=offset,
            validation_status=validation_status,
            request_options=request_options,
        )
        return response.data

    async def create(
        self,
        *,
        name: str,
        type: DatasetType,
        data: core.File,
        keep_original_file: typing.Optional[bool] = None,
        skip_malformed_input: typing.Optional[bool] = None,
        keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        text_separator: typing.Optional[str] = None,
        csv_delimiter: typing.Optional[str] = None,
        eval_data: typing.Optional[core.File] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsCreateResponse:
        """
        Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information.

        Parameters
        ----------
        name : str
            The name of the uploaded dataset.
        type : DatasetType
            The dataset type, used to validate the data; the only valid type is `embed-input`, used with the Embed Jobs API.
        data : core.File
            The file to upload. See core.File for more documentation.
        keep_original_file : typing.Optional[bool]
            Indicates if the original file should be stored.
        skip_malformed_input : typing.Optional[bool]
            Drop rows with malformed input instead of failing validation; dropped rows are returned in the warnings field.
        keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            Names of fields to persist in the Dataset beyond the required schema fields; validation fails if any are missing from the uploaded file.
        optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            Names of fields to persist in the Dataset beyond the required schema fields; validation still passes if some are missing from the uploaded file.
        text_separator : typing.Optional[str]
            Raw .txt uploads are split into entries using this separator.
        csv_delimiter : typing.Optional[str]
            The delimiter used for .csv uploads.
        eval_data : typing.Optional[core.File]
            Optional evaluation data file. See core.File for more documentation.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsCreateResponse
            A successful response.
        """
        response = await self._raw_client.create(
            name=name,
            type=type,
            data=data,
            keep_original_file=keep_original_file,
            skip_malformed_input=skip_malformed_input,
            keep_fields=keep_fields,
            optional_fields=optional_fields,
            text_separator=text_separator,
            csv_delimiter=csv_delimiter,
            eval_data=eval_data,
            request_options=request_options,
        )
        return response.data

    async def get_usage(self, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetUsageResponse:
        """
        View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetUsageResponse
            A successful response.
        """
        response = await self._raw_client.get_usage(request_options=request_options)
        return response.data

    async def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetResponse:
        """
        Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.

        Parameters
        ----------
        id : str
            The dataset ID.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetResponse
            A successful response.
        """
        response = await self._raw_client.get(id, request_options=request_options)
        return response.data

    async def delete(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> typing.Dict[str, typing.Any]:
        """
        Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually.

        Parameters
        ----------
        id : str
            The dataset ID.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        typing.Dict[str, typing.Any]
            A successful response.
        """
        response = await self._raw_client.delete(id, request_options=request_options)
        return response.data
import core from ..core.api_error import ApiError from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.datetime_utils import serialize_datetime from ..core.http_response import AsyncHttpResponse, HttpResponse from ..core.jsonable_encoder import jsonable_encoder from ..core.parse_error import ParsingError from ..core.request_options import RequestOptions from ..core.unchecked_base_model import construct_type from ..errors.bad_request_error import BadRequestError from ..errors.client_closed_request_error import ClientClosedRequestError from ..errors.forbidden_error import ForbiddenError from ..errors.gateway_timeout_error import GatewayTimeoutError from ..errors.internal_server_error import InternalServerError from ..errors.invalid_token_error import InvalidTokenError from ..errors.not_found_error import NotFoundError from ..errors.not_implemented_error import NotImplementedError from ..errors.service_unavailable_error import ServiceUnavailableError from ..errors.too_many_requests_error import TooManyRequestsError from ..errors.unauthorized_error import UnauthorizedError from ..errors.unprocessable_entity_error import UnprocessableEntityError from ..types.dataset_type import DatasetType from ..types.dataset_validation_status import DatasetValidationStatus from .types.datasets_create_response import DatasetsCreateResponse from .types.datasets_get_response import DatasetsGetResponse from .types.datasets_get_usage_response import DatasetsGetUsageResponse from .types.datasets_list_response import DatasetsListResponse from pydantic import ValidationError # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) 
class RawDatasetsClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._client_wrapper = client_wrapper def list( self, *, dataset_type: typing.Optional[str] = None, before: typing.Optional[dt.datetime] = None, after: typing.Optional[dt.datetime] = None, limit: typing.Optional[float] = None, offset: typing.Optional[float] = None, validation_status: typing.Optional[DatasetValidationStatus] = None, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[DatasetsListResponse]: """ List datasets that have been created. Parameters ---------- dataset_type : typing.Optional[str] optional filter by dataset type before : typing.Optional[dt.datetime] optional filter before a date after : typing.Optional[dt.datetime] optional filter after a date limit : typing.Optional[float] optional limit to number of results offset : typing.Optional[float] optional offset to start of results validation_status : typing.Optional[DatasetValidationStatus] optional filter by validation status request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[DatasetsListResponse] A successful response. 
""" _response = self._client_wrapper.httpx_client.request( "v1/datasets", method="GET", params={ "datasetType": dataset_type, "before": serialize_datetime(before) if before is not None else None, "after": serialize_datetime(after) if after is not None else None, "limit": limit, "offset": offset, "validationStatus": validation_status, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsListResponse, construct_type( type_=DatasetsListResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def create( self, *, name: str, type: DatasetType, data: core.File, keep_original_file: typing.Optional[bool] = None, skip_malformed_input: typing.Optional[bool] = None, keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None, optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None, text_separator: typing.Optional[str] = None, csv_delimiter: typing.Optional[str] = None, eval_data: typing.Optional[core.File] = OMIT, request_options: 
typing.Optional[RequestOptions] = None, ) -> HttpResponse[DatasetsCreateResponse]: """ Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information. Parameters ---------- name : str The name of the uploaded dataset. type : DatasetType The dataset type, which is used to validate the data. The only valid type is `embed-input` used in conjunction with the Embed Jobs API. data : core.File See core.File for more documentation keep_original_file : typing.Optional[bool] Indicates if the original file should be stored. skip_malformed_input : typing.Optional[bool] Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field. keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]] List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `keep_fields` are missing from the uploaded file, Dataset validation will fail. optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]] List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, Datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `optional_fields` are missing from the uploaded file, Dataset validation will pass. text_separator : typing.Optional[str] Raw .txt uploads will be split into entries using the text_separator value. 
csv_delimiter : typing.Optional[str] The delimiter used for .csv uploads. eval_data : typing.Optional[core.File] See core.File for more documentation request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[DatasetsCreateResponse] A successful response. """ _response = self._client_wrapper.httpx_client.request( "v1/datasets", method="POST", params={ "name": name, "type": type, "keep_original_file": keep_original_file, "skip_malformed_input": skip_malformed_input, "keep_fields": keep_fields, "optional_fields": optional_fields, "text_separator": text_separator, "csv_delimiter": csv_delimiter, }, data={}, files={ "data": data, **({"eval_data": eval_data} if eval_data is not None else {}), }, request_options=request_options, omit=OMIT, force_multipart=True, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsCreateResponse, construct_type( type_=DatasetsCreateResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, 
construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def get_usage( self, *, request_options: 
typing.Optional[RequestOptions] = None ) -> HttpResponse[DatasetsGetUsageResponse]: """ View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users. Parameters ---------- request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[DatasetsGetUsageResponse] A successful response. """ _response = self._client_wrapper.httpx_client.request( "v1/datasets/usage", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsGetUsageResponse, construct_type( type_=DatasetsGetUsageResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if 
_response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def get( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[DatasetsGetResponse]: """ Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information. Parameters ---------- id : str request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[DatasetsGetResponse] A successful response. """ _response = self._client_wrapper.httpx_client.request( f"v1/datasets/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsGetResponse, construct_type( type_=DatasetsGetResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( 
typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def delete( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[typing.Dict[str, typing.Any]]: """ Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually. Parameters ---------- id : str request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[typing.Dict[str, typing.Any]] A successful response. 
""" _response = self._client_wrapper.httpx_client.request( f"v1/datasets/{jsonable_encoder(id)}", method="DELETE", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( typing.Dict[str, typing.Any], construct_type( type_=typing.Dict[str, typing.Any], # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: 
ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawDatasetsClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper async def list( self, *, dataset_type: typing.Optional[str] = None, before: typing.Optional[dt.datetime] = None, after: typing.Optional[dt.datetime] = None, limit: typing.Optional[float] = None, offset: typing.Optional[float] = None, validation_status: typing.Optional[DatasetValidationStatus] = None, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[DatasetsListResponse]: """ List datasets that have been created. 
Parameters ---------- dataset_type : typing.Optional[str] optional filter by dataset type before : typing.Optional[dt.datetime] optional filter before a date after : typing.Optional[dt.datetime] optional filter after a date limit : typing.Optional[float] optional limit to number of results offset : typing.Optional[float] optional offset to start of results validation_status : typing.Optional[DatasetValidationStatus] optional filter by validation status request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[DatasetsListResponse] A successful response. """ _response = await self._client_wrapper.httpx_client.request( "v1/datasets", method="GET", params={ "datasetType": dataset_type, "before": serialize_datetime(before) if before is not None else None, "after": serialize_datetime(after) if after is not None else None, "limit": limit, "offset": offset, "validationStatus": validation_status, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsListResponse, construct_type( type_=DatasetsListResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: 
ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), 
body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def create( self, *, name: str, type: DatasetType, data: core.File, keep_original_file: typing.Optional[bool] = None, skip_malformed_input: typing.Optional[bool] = None, keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None, optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None, text_separator: typing.Optional[str] = None, csv_delimiter: typing.Optional[str] = None, eval_data: typing.Optional[core.File] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[DatasetsCreateResponse]: """ Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information. Parameters ---------- name : str The name of the uploaded dataset. type : DatasetType The dataset type, which is used to validate the data. The only valid type is `embed-input` used in conjunction with the Embed Jobs API. data : core.File See core.File for more documentation keep_original_file : typing.Optional[bool] Indicates if the original file should be stored. skip_malformed_input : typing.Optional[bool] Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field. keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]] List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `keep_fields` are missing from the uploaded file, Dataset validation will fail. 
optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]] List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, Datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `optional_fields` are missing from the uploaded file, Dataset validation will pass. text_separator : typing.Optional[str] Raw .txt uploads will be split into entries using the text_separator value. csv_delimiter : typing.Optional[str] The delimiter used for .csv uploads. eval_data : typing.Optional[core.File] See core.File for more documentation request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[DatasetsCreateResponse] A successful response. """ _response = await self._client_wrapper.httpx_client.request( "v1/datasets", method="POST", params={ "name": name, "type": type, "keep_original_file": keep_original_file, "skip_malformed_input": skip_malformed_input, "keep_fields": keep_fields, "optional_fields": optional_fields, "text_separator": text_separator, "csv_delimiter": csv_delimiter, }, data={}, files={ "data": data, **({"eval_data": eval_data} if eval_data is not None else {}), }, request_options=request_options, omit=OMIT, force_multipart=True, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsCreateResponse, construct_type( type_=DatasetsCreateResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), 
body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 
504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def get_usage( self, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[DatasetsGetUsageResponse]: """ View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users. Parameters ---------- request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[DatasetsGetUsageResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( "v1/datasets/usage", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsGetUsageResponse, construct_type( type_=DatasetsGetUsageResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def get( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[DatasetsGetResponse]: """ Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information. Parameters ---------- id : str request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[DatasetsGetResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( f"v1/datasets/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DatasetsGetResponse, construct_type( type_=DatasetsGetResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def delete( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[typing.Dict[str, typing.Any]]: """ Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually. Parameters ---------- id : str request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[typing.Dict[str, typing.Any]] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( f"v1/datasets/{jsonable_encoder(id)}", method="DELETE", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( typing.Dict[str, typing.Any], construct_type( type_=typing.Dict[str, typing.Any], # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # 
type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) ================================================ FILE: src/cohere/datasets/types/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# isort: skip_file
import typing
from importlib import import_module

# Imported eagerly only for type checkers; at runtime the names below are
# resolved lazily through the module-level __getattr__ (PEP 562).
if typing.TYPE_CHECKING:
    from .datasets_create_response import DatasetsCreateResponse
    from .datasets_get_response import DatasetsGetResponse
    from .datasets_get_usage_response import DatasetsGetUsageResponse
    from .datasets_list_response import DatasetsListResponse

# Public name -> relative module that provides it.
_dynamic_imports: typing.Dict[str, str] = {
    "DatasetsCreateResponse": ".datasets_create_response",
    "DatasetsGetResponse": ".datasets_get_response",
    "DatasetsGetUsageResponse": ".datasets_get_usage_response",
    "DatasetsListResponse": ".datasets_list_response",
}


def __getattr__(attr_name: str) -> typing.Any:
    """Resolve a re-exported name on first attribute access (lazy import)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # A mapping of ".name" -> the module itself; otherwise fetch the symbol.
        return module if module_name == f".{attr_name}" else getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e


def __dir__():
    """List the lazily importable public names for dir()/autocomplete."""
    return sorted(_dynamic_imports)


__all__ = ["DatasetsCreateResponse", "DatasetsGetResponse", "DatasetsGetUsageResponse", "DatasetsListResponse"]


# ================================================
# FILE: src/cohere/datasets/types/datasets_create_response.py
# ================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel


class DatasetsCreateResponse(UncheckedBaseModel):
    """Response body for dataset creation; carries only the new dataset's ID."""

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The dataset ID
    """

    # Unknown fields returned by the API are retained (extra="allow"), under
    # both the pydantic v2 and v1 configuration styles.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# ================================================
# FILE: src/cohere/datasets/types/datasets_get_response.py
# ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.dataset import Dataset


class DatasetsGetResponse(UncheckedBaseModel):
    """Response body wrapping the single Dataset returned by get-by-ID."""

    # Required: the requested dataset.
    dataset: Dataset

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# ================================================
# FILE: src/cohere/datasets/types/datasets_get_usage_response.py
# ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel


class DatasetsGetUsageResponse(UncheckedBaseModel):
    """Response body for the organization-wide dataset storage usage endpoint."""

    organization_usage: typing.Optional[int] = pydantic.Field(default=None)
    """
    The total number of bytes used by the organization.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# ================================================
# FILE: src/cohere/datasets/types/datasets_list_response.py
# ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.dataset import Dataset


class DatasetsListResponse(UncheckedBaseModel):
    """Response body for the list-datasets endpoint: zero or more datasets."""

    # Optional: absent/None when there are no datasets to report.
    datasets: typing.Optional[typing.List[Dataset]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# ================================================
# FILE: src/cohere/embed_jobs/__init__.py
# ================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .types import CreateEmbedJobRequestTruncate _dynamic_imports: typing.Dict[str, str] = {"CreateEmbedJobRequestTruncate": ".types"} def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = ["CreateEmbedJobRequestTruncate"] ================================================ FILE: src/cohere/embed_jobs/client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.request_options import RequestOptions from ..types.create_embed_job_response import CreateEmbedJobResponse from ..types.embed_input_type import EmbedInputType from ..types.embed_job import EmbedJob from ..types.embedding_type import EmbeddingType from ..types.list_embed_job_response import ListEmbedJobResponse from .raw_client import AsyncRawEmbedJobsClient, RawEmbedJobsClient from .types.create_embed_job_request_truncate import CreateEmbedJobRequestTruncate # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) 
class EmbedJobsClient:
    """Synchronous client for the embed-jobs endpoints.

    Thin convenience wrapper around ``RawEmbedJobsClient``: every method
    delegates to the raw client and unwraps the parsed ``.data`` payload.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawEmbedJobsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawEmbedJobsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawEmbedJobsClient
        """
        return self._raw_client

    def list(self, *, request_options: typing.Optional[RequestOptions] = None) -> ListEmbedJobResponse:
        """
        The list embed job endpoint allows users to view all embed jobs history for that specific user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListEmbedJobResponse
            OK
        """
        return self._raw_client.list(request_options=request_options).data

    def create(
        self,
        *,
        model: str,
        dataset_id: str,
        input_type: EmbedInputType,
        name: typing.Optional[str] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> CreateEmbedJobResponse:
        """
        Launch an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`.
        A completed job produces a new Dataset of type `embed-output` containing the
        original text entries and their embeddings.

        Parameters
        ----------
        model : str
            ID of the embedding model (e.g. `embed-english-v3.0`).

        dataset_id : str
            ID of a Dataset of type `embed-input` with validation status `Validated`.

        input_type : EmbedInputType

        name : typing.Optional[str]
            The name of the embed job.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Embedding types to return (`float`, `int8`, `uint8`, `binary`, `ubinary`);
            defaults to float embeddings.

        truncate : typing.Optional[CreateEmbedJobRequestTruncate]
            `START` or `END`: which side of an over-long input to discard.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateEmbedJobResponse
            OK
        """
        return self._raw_client.create(
            model=model,
            dataset_id=dataset_id,
            input_type=input_type,
            name=name,
            embedding_types=embedding_types,
            truncate=truncate,
            request_options=request_options,
        ).data

    def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> EmbedJob:
        """
        Retrieve the details of an embed job started by the same user.

        Parameters
        ----------
        id : str
            The ID of the embed job to retrieve.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedJob
            OK
        """
        return self._raw_client.get(id, request_options=request_options).data

    def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> None:
        """
        Cancel an active embed job. Embedding stops immediately; usage up to the
        cancellation point is billed and partial results are not available.

        Parameters
        ----------
        id : str
            The ID of the embed job to cancel.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        None
        """
        return self._raw_client.cancel(id, request_options=request_options).data
class AsyncEmbedJobsClient:
    """Asynchronous client for the embed-jobs endpoints.

    Thin convenience wrapper around ``AsyncRawEmbedJobsClient``: every method
    awaits the raw client and unwraps the parsed ``.data`` payload.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawEmbedJobsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawEmbedJobsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawEmbedJobsClient
        """
        return self._raw_client

    async def list(self, *, request_options: typing.Optional[RequestOptions] = None) -> ListEmbedJobResponse:
        """
        The list embed job endpoint allows users to view all embed jobs history for that specific user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListEmbedJobResponse
            OK
        """
        return (await self._raw_client.list(request_options=request_options)).data

    async def create(
        self,
        *,
        model: str,
        dataset_id: str,
        input_type: EmbedInputType,
        name: typing.Optional[str] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> CreateEmbedJobResponse:
        """
        Launch an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`.
        A completed job produces a new Dataset of type `embed-output` containing the
        original text entries and their embeddings.

        Parameters
        ----------
        model : str
            ID of the embedding model (e.g. `embed-english-v3.0`).

        dataset_id : str
            ID of a Dataset of type `embed-input` with validation status `Validated`.

        input_type : EmbedInputType

        name : typing.Optional[str]
            The name of the embed job.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Embedding types to return (`float`, `int8`, `uint8`, `binary`, `ubinary`);
            defaults to float embeddings.

        truncate : typing.Optional[CreateEmbedJobRequestTruncate]
            `START` or `END`: which side of an over-long input to discard.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateEmbedJobResponse
            OK
        """
        _raw = await self._raw_client.create(
            model=model,
            dataset_id=dataset_id,
            input_type=input_type,
            name=name,
            embedding_types=embedding_types,
            truncate=truncate,
            request_options=request_options,
        )
        return _raw.data

    async def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> EmbedJob:
        """
        Retrieve the details of an embed job started by the same user.

        Parameters
        ----------
        id : str
            The ID of the embed job to retrieve.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedJob
            OK
        """
        return (await self._raw_client.get(id, request_options=request_options)).data

    async def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> None:
        """
        Cancel an active embed job. Embedding stops immediately; usage up to the
        cancellation point is billed and partial results are not available.

        Parameters
        ----------
        id : str
            The ID of the embed job to cancel.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        None
        """
        return (await self._raw_client.cancel(id, request_options=request_options)).data
Returns ------- None Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.embed_jobs.cancel( id="id", ) asyncio.run(main()) """ _response = await self._raw_client.cancel(id, request_options=request_options) return _response.data ================================================ FILE: src/cohere/embed_jobs/raw_client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from json.decoder import JSONDecodeError from ..core.api_error import ApiError from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.http_response import AsyncHttpResponse, HttpResponse from ..core.jsonable_encoder import jsonable_encoder from ..core.parse_error import ParsingError from ..core.request_options import RequestOptions from ..core.unchecked_base_model import construct_type from ..errors.bad_request_error import BadRequestError from ..errors.client_closed_request_error import ClientClosedRequestError from ..errors.forbidden_error import ForbiddenError from ..errors.gateway_timeout_error import GatewayTimeoutError from ..errors.internal_server_error import InternalServerError from ..errors.invalid_token_error import InvalidTokenError from ..errors.not_found_error import NotFoundError from ..errors.not_implemented_error import NotImplementedError from ..errors.service_unavailable_error import ServiceUnavailableError from ..errors.too_many_requests_error import TooManyRequestsError from ..errors.unauthorized_error import UnauthorizedError from ..errors.unprocessable_entity_error import UnprocessableEntityError from ..types.create_embed_job_response import CreateEmbedJobResponse from ..types.embed_input_type import EmbedInputType from ..types.embed_job import EmbedJob from ..types.embedding_type import EmbeddingType from ..types.list_embed_job_response import 
ListEmbedJobResponse from .types.create_embed_job_request_truncate import CreateEmbedJobRequestTruncate from pydantic import ValidationError # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) class RawEmbedJobsClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._client_wrapper = client_wrapper def list(self, *, request_options: typing.Optional[RequestOptions] = None) -> HttpResponse[ListEmbedJobResponse]: """ The list embed job endpoint allows users to view all embed jobs history for that specific user. Parameters ---------- request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[ListEmbedJobResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/embed-jobs", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListEmbedJobResponse, construct_type( type_=ListEmbedJobResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), 
body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def create( self, *, 
model: str, dataset_id: str, input_type: EmbedInputType, name: typing.Optional[str] = OMIT, embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT, truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[CreateEmbedJobResponse]: """ This API launches an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`. The result of a completed embed job is new Dataset of type `embed-output`, which contains the original text entries and the corresponding embeddings. Parameters ---------- model : str ID of the embedding model. Available models and corresponding embedding dimensions: - `embed-english-v3.0` : 1024 - `embed-multilingual-v3.0` : 1024 - `embed-english-light-v3.0` : 384 - `embed-multilingual-light-v3.0` : 384 dataset_id : str ID of a [Dataset](https://docs.cohere.com/docs/datasets). The Dataset must be of type `embed-input` and must have a validation status `Validated` input_type : EmbedInputType name : typing.Optional[str] The name of the embed job. embedding_types : typing.Optional[typing.Sequence[EmbeddingType]] Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Valid for all models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for v3 and newer model versions. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for v3 and newer model versions. * `"binary"`: Use this when you want to get back signed binary embeddings. Valid for v3 and newer model versions. * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for v3 and newer model versions. 
truncate : typing.Optional[CreateEmbedJobRequestTruncate] One of `START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[CreateEmbedJobResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/embed-jobs", method="POST", json={ "model": model, "dataset_id": dataset_id, "input_type": input_type, "name": name, "embedding_types": embedding_types, "truncate": truncate, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CreateEmbedJobResponse, construct_type( type_=CreateEmbedJobResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( 
typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def get(self, id: str, *, 
request_options: typing.Optional[RequestOptions] = None) -> HttpResponse[EmbedJob]: """ This API retrieves the details about an embed job started by the same user. Parameters ---------- id : str The ID of the embed job to retrieve. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[EmbedJob] OK """ _response = self._client_wrapper.httpx_client.request( f"v1/embed-jobs/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( EmbedJob, construct_type( type_=EmbedJob, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> HttpResponse[None]: """ This API allows users to cancel an active embed job. Once invoked, the embedding process will be terminated, and users will be charged for the embeddings processed up to the cancellation point. 
It's important to note that partial results will not be available to users after cancellation. Parameters ---------- id : str The ID of the embed job to cancel. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[None] """ _response = self._client_wrapper.httpx_client.request( f"v1/embed-jobs/{jsonable_encoder(id)}/cancel", method="POST", request_options=request_options, ) try: if 200 <= _response.status_code < 300: return HttpResponse(response=_response, data=None) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise 
ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawEmbedJobsClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper async def list( self, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[ListEmbedJobResponse]: """ The list embed job endpoint allows users to view all embed jobs history for that specific user. Parameters ---------- request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[ListEmbedJobResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/embed-jobs", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListEmbedJobResponse, construct_type( type_=ListEmbedJobResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( 
type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def create( self, *, model: str, dataset_id: str, input_type: EmbedInputType, name: typing.Optional[str] = OMIT, embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT, truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[CreateEmbedJobResponse]: """ This API launches an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`. The result of a completed embed job is new Dataset of type `embed-output`, which contains the original text entries and the corresponding embeddings. 
Parameters ---------- model : str ID of the embedding model. Available models and corresponding embedding dimensions: - `embed-english-v3.0` : 1024 - `embed-multilingual-v3.0` : 1024 - `embed-english-light-v3.0` : 384 - `embed-multilingual-light-v3.0` : 384 dataset_id : str ID of a [Dataset](https://docs.cohere.com/docs/datasets). The Dataset must be of type `embed-input` and must have a validation status `Validated` input_type : EmbedInputType name : typing.Optional[str] The name of the embed job. embedding_types : typing.Optional[typing.Sequence[EmbeddingType]] Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Valid for all models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for v3 and newer model versions. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for v3 and newer model versions. * `"binary"`: Use this when you want to get back signed binary embeddings. Valid for v3 and newer model versions. * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for v3 and newer model versions. truncate : typing.Optional[CreateEmbedJobRequestTruncate] One of `START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[CreateEmbedJobResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/embed-jobs", method="POST", json={ "model": model, "dataset_id": dataset_id, "input_type": input_type, "name": name, "embedding_types": embedding_types, "truncate": truncate, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CreateEmbedJobResponse, construct_type( type_=CreateEmbedJobResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, 
construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def get( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[EmbedJob]: """ This API retrieves the details about an embed job started by the same user. Parameters ---------- id : str The ID of the embed job to retrieve. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[EmbedJob] OK """ _response = await self._client_wrapper.httpx_client.request( f"v1/embed-jobs/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( EmbedJob, construct_type( type_=EmbedJob, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( 
type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def cancel( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[None]: """ This API allows users to cancel an active embed job. Once invoked, the embedding process will be terminated, and users will be charged for the embeddings processed up to the cancellation point. It's important to note that partial results will not be available to users after cancellation. Parameters ---------- id : str The ID of the embed job to cancel. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[None] """ _response = await self._client_wrapper.httpx_client.request( f"v1/embed-jobs/{jsonable_encoder(id)}/cancel", method="POST", request_options=request_options, ) try: if 200 <= _response.status_code < 300: return AsyncHttpResponse(response=_response, data=None) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise 
InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) ================================================ FILE: src/cohere/embed_jobs/types/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# --- Reviewer notes (flattened multi-file dump; the code line below is byte-identical) ---
# Span contents:
#  * src/cohere/embed_jobs/types/__init__.py — PEP 562 lazy-import shim: the real import
#    runs only under typing.TYPE_CHECKING; at runtime, module-level __getattr__ resolves
#    a name on first access via importlib.import_module using the _dynamic_imports
#    attr-name -> relative-module map, and __dir__ advertises the lazily importable names.
#    ImportError/AttributeError from the lazy import are re-raised with context (`from e`).
#  * create_embed_job_request_truncate.py — CreateEmbedJobRequestTruncate is
#    Literal["START", "END"] unioned with typing.Any (i.e. an open enum: unknown values
#    are accepted by type checkers).
#  * environment.py — ClientEnvironment enum with a single PRODUCTION base URL
#    ("https://api.cohere.com").
# NOTE(review): each physical line in this dump is a flattened source chunk — do not reflow.
# isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .create_embed_job_request_truncate import CreateEmbedJobRequestTruncate _dynamic_imports: typing.Dict[str, str] = {"CreateEmbedJobRequestTruncate": ".create_embed_job_request_truncate"} def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = ["CreateEmbedJobRequestTruncate"] ================================================ FILE: src/cohere/embed_jobs/types/create_embed_job_request_truncate.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing CreateEmbedJobRequestTruncate = typing.Union[typing.Literal["START", "END"], typing.Any] ================================================ FILE: src/cohere/environment.py ================================================ # This file was auto-generated by Fern from our API Definition. import enum class ClientEnvironment(enum.Enum): PRODUCTION = "https://api.cohere.com" ================================================ FILE: src/cohere/errors/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# --- Reviewer notes (flattened dump; the code lines below are byte-identical) ---
# Span contents:
#  * src/cohere/errors/__init__.py — same PEP 562 lazy-import pattern as the other
#    package __init__ files: _dynamic_imports maps each of the 12 error class names to
#    its relative module; __getattr__ imports on demand, __dir__ sorts the lazy names,
#    and __all__ declares the public surface.
#  * bad_request_error.py / client_closed_request_error.py / forbidden_error.py —
#    each error is an ApiError subclass whose __init__ pins the HTTP status code
#    (400, 499, 403 respectively) and forwards the optional headers dict and the
#    parsed body to ApiError. 499 is a non-standard "client closed request" code
#    (nginx convention — matches the ClientClosedRequestError name).
# End of span: auto-generated header of errors/gateway_timeout_error.py; its code is
# on the next physical line.
# isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .bad_request_error import BadRequestError from .client_closed_request_error import ClientClosedRequestError from .forbidden_error import ForbiddenError from .gateway_timeout_error import GatewayTimeoutError from .internal_server_error import InternalServerError from .invalid_token_error import InvalidTokenError from .not_found_error import NotFoundError from .not_implemented_error import NotImplementedError from .service_unavailable_error import ServiceUnavailableError from .too_many_requests_error import TooManyRequestsError from .unauthorized_error import UnauthorizedError from .unprocessable_entity_error import UnprocessableEntityError _dynamic_imports: typing.Dict[str, str] = { "BadRequestError": ".bad_request_error", "ClientClosedRequestError": ".client_closed_request_error", "ForbiddenError": ".forbidden_error", "GatewayTimeoutError": ".gateway_timeout_error", "InternalServerError": ".internal_server_error", "InvalidTokenError": ".invalid_token_error", "NotFoundError": ".not_found_error", "NotImplementedError": ".not_implemented_error", "ServiceUnavailableError": ".service_unavailable_error", "TooManyRequestsError": ".too_many_requests_error", "UnauthorizedError": ".unauthorized_error", "UnprocessableEntityError": ".unprocessable_entity_error", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = 
list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ "BadRequestError", "ClientClosedRequestError", "ForbiddenError", "GatewayTimeoutError", "InternalServerError", "InvalidTokenError", "NotFoundError", "NotImplementedError", "ServiceUnavailableError", "TooManyRequestsError", "UnauthorizedError", "UnprocessableEntityError", ] ================================================ FILE: src/cohere/errors/bad_request_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class BadRequestError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=400, headers=headers, body=body) ================================================ FILE: src/cohere/errors/client_closed_request_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class ClientClosedRequestError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=499, headers=headers, body=body) ================================================ FILE: src/cohere/errors/forbidden_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class ForbiddenError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=403, headers=headers, body=body) ================================================ FILE: src/cohere/errors/gateway_timeout_error.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# --- Reviewer notes (flattened dump; the code line below is byte-identical) ---
# Four more ApiError subclasses, one file each, all with the same shape
# (__init__ pins status_code, forwards optional headers and parsed body):
#  * GatewayTimeoutError   -> 504
#  * InternalServerError   -> 500
#  * InvalidTokenError     -> 498  (non-standard status; presumably an API-specific
#                                   "token expired/invalid" code — confirm against server docs)
#  * NotFoundError         -> 404
# End of span: auto-generated header of errors/not_implemented_error.py.
import typing from ..core.api_error import ApiError class GatewayTimeoutError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=504, headers=headers, body=body) ================================================ FILE: src/cohere/errors/internal_server_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class InternalServerError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=500, headers=headers, body=body) ================================================ FILE: src/cohere/errors/invalid_token_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class InvalidTokenError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=498, headers=headers, body=body) ================================================ FILE: src/cohere/errors/not_found_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class NotFoundError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=404, headers=headers, body=body) ================================================ FILE: src/cohere/errors/not_implemented_error.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# --- Reviewer notes (flattened dump; the code line below is byte-identical) ---
# Four more ApiError subclasses with the same shape:
#  * NotImplementedError    -> 501
#    NOTE(review): this class deliberately shares its name with Python's builtin
#    NotImplementedError but does NOT subclass it (it derives from ApiError).
#    `from cohere.errors import NotImplementedError` shadows the builtin in the
#    importing scope — callers should import the module and qualify instead.
#  * ServiceUnavailableError -> 503
#  * TooManyRequestsError    -> 429
#  * UnauthorizedError       -> 401
# End of span: auto-generated header of errors/unprocessable_entity_error.py.
import typing from ..core.api_error import ApiError class NotImplementedError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=501, headers=headers, body=body) ================================================ FILE: src/cohere/errors/service_unavailable_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class ServiceUnavailableError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=503, headers=headers, body=body) ================================================ FILE: src/cohere/errors/too_many_requests_error.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.api_error import ApiError class TooManyRequestsError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=429, headers=headers, body=body) ================================================ FILE: src/cohere/errors/unauthorized_error.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# --- Reviewer notes (flattened dump; the code lines below are byte-identical) ---
# Span contents:
#  * unprocessable_entity_error.py — UnprocessableEntityError(ApiError), status 422,
#    same shape as the other error classes.
#  * src/cohere/finetuning/__init__.py — PEP 562 lazy loader re-exporting the 18
#    finetuning type names plus the `finetuning` submodule itself; every map entry
#    points at ".finetuning", and __getattr__ returns the module object (not an
#    attribute) when the requested name equals the module name — that is what the
#    `module_name == f".{attr_name}"` branch is for.
#  * src/cohere/finetuning/client.py (head) — imports of the sync/async client
#    wrappers, response types, and raw clients, plus `OMIT = typing.cast(typing.Any, ...)`:
#    the Ellipsis sentinel the generated methods use to distinguish "parameter not
#    supplied" from an explicit None (per the comment in the source).
# NOTE(review): the lazy-loader __getattr__ here is split across the two physical
# lines below ("return" / "getattr(module, attr_name)") — an artifact of the dump.
import typing from ..core.api_error import ApiError class UnprocessableEntityError(ApiError): def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None): super().__init__(status_code=422, headers=headers, body=body) ================================================ FILE: src/cohere/finetuning/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. # isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from . import finetuning from .finetuning import ( BaseModel, BaseType, CreateFinetunedModelResponse, DeleteFinetunedModelResponse, Event, FinetunedModel, GetFinetunedModelResponse, Hyperparameters, ListEventsResponse, ListFinetunedModelsResponse, ListTrainingStepMetricsResponse, LoraTargetModules, Settings, Status, Strategy, TrainingStepMetrics, UpdateFinetunedModelResponse, WandbConfig, ) _dynamic_imports: typing.Dict[str, str] = { "BaseModel": ".finetuning", "BaseType": ".finetuning", "CreateFinetunedModelResponse": ".finetuning", "DeleteFinetunedModelResponse": ".finetuning", "Event": ".finetuning", "FinetunedModel": ".finetuning", "GetFinetunedModelResponse": ".finetuning", "Hyperparameters": ".finetuning", "ListEventsResponse": ".finetuning", "ListFinetunedModelsResponse": ".finetuning", "ListTrainingStepMetricsResponse": ".finetuning", "LoraTargetModules": ".finetuning", "Settings": ".finetuning", "Status": ".finetuning", "Strategy": ".finetuning", "TrainingStepMetrics": ".finetuning", "UpdateFinetunedModelResponse": ".finetuning", "WandbConfig": ".finetuning", "finetuning": ".finetuning", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return 
getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ "BaseModel", "BaseType", "CreateFinetunedModelResponse", "DeleteFinetunedModelResponse", "Event", "FinetunedModel", "GetFinetunedModelResponse", "Hyperparameters", "ListEventsResponse", "ListFinetunedModelsResponse", "ListTrainingStepMetricsResponse", "LoraTargetModules", "Settings", "Status", "Strategy", "TrainingStepMetrics", "UpdateFinetunedModelResponse", "WandbConfig", "finetuning", ] ================================================ FILE: src/cohere/finetuning/client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.request_options import RequestOptions from .finetuning.types.create_finetuned_model_response import CreateFinetunedModelResponse from .finetuning.types.delete_finetuned_model_response import DeleteFinetunedModelResponse from .finetuning.types.finetuned_model import FinetunedModel from .finetuning.types.get_finetuned_model_response import GetFinetunedModelResponse from .finetuning.types.list_events_response import ListEventsResponse from .finetuning.types.list_finetuned_models_response import ListFinetunedModelsResponse from .finetuning.types.list_training_step_metrics_response import ListTrainingStepMetricsResponse from .finetuning.types.settings import Settings from .finetuning.types.update_finetuned_model_response import UpdateFinetunedModelResponse from .raw_client import AsyncRawFinetuningClient, RawFinetuningClient # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) 
# --- Reviewer notes (flattened dump; the code lines below are byte-identical) ---
# src/cohere/finetuning/client.py (body):
#  * FinetuningClient — synchronous facade over RawFinetuningClient. Every public
#    method (list_finetuned_models, create_finetuned_model, get_finetuned_model,
#    delete_finetuned_model, update_finetuned_model, list_events,
#    list_training_step_metrics) forwards its arguments unchanged to the raw
#    client and returns `_response.data`, discarding status/headers; callers that
#    need the raw HTTP response use the `with_raw_response` property instead.
#  * AsyncFinetuningClient — identical method surface, awaiting
#    AsyncRawFinetuningClient; docstring examples use asyncio.run accordingly.
#  * Pagination is server-driven via page_size/page_token; order_by accepts a
#    comma-separated field list with optional " desc" suffix (per the docstrings).
# End of span: start of src/cohere/finetuning/finetuning/__init__.py — another
# PEP 562 lazy-import shim whose __getattr__ continues on the next physical line.
class FinetuningClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._raw_client = RawFinetuningClient(client_wrapper=client_wrapper) @property def with_raw_response(self) -> RawFinetuningClient: """ Retrieves a raw implementation of this client that returns raw responses. Returns ------- RawFinetuningClient """ return self._raw_client def list_finetuned_models( self, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListFinetunedModelsResponse: """ Returns a list of fine-tuned models that the user has access to. Parameters ---------- page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. order_by : typing.Optional[str] Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default) request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListFinetunedModelsResponse A successful response. Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.finetuning.list_finetuned_models( page_size=1, page_token="page_token", order_by="order_by", ) """ _response = self._raw_client.list_finetuned_models( page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options ) return _response.data def create_finetuned_model( self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None ) -> CreateFinetunedModelResponse: """ Creates a new fine-tuned model. The model will be trained on the dataset specified in the request body. 
The training process may take some time, and the model will be available once the training is complete. Parameters ---------- request : FinetunedModel request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- CreateFinetunedModelResponse A successful response. Examples -------- from cohere import Client from cohere.finetuning.finetuning import BaseModel, FinetunedModel, Settings client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.finetuning.create_finetuned_model( request=FinetunedModel( name="name", settings=Settings( base_model=BaseModel( base_type="BASE_TYPE_UNSPECIFIED", ), dataset_id="dataset_id", ), ), ) """ _response = self._raw_client.create_finetuned_model(request=request, request_options=request_options) return _response.data def get_finetuned_model( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> GetFinetunedModelResponse: """ Retrieve a fine-tuned model by its ID. Parameters ---------- id : str The fine-tuned model ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- GetFinetunedModelResponse A successful response. Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.finetuning.get_finetuned_model( id="id", ) """ _response = self._raw_client.get_finetuned_model(id, request_options=request_options) return _response.data def delete_finetuned_model( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> DeleteFinetunedModelResponse: """ Deletes a fine-tuned model. The model will be removed from the system and will no longer be available for use. This operation is irreversible. Parameters ---------- id : str The fine-tuned model ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- DeleteFinetunedModelResponse A successful response. 
Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.finetuning.delete_finetuned_model( id="id", ) """ _response = self._raw_client.delete_finetuned_model(id, request_options=request_options) return _response.data def update_finetuned_model( self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None ) -> UpdateFinetunedModelResponse: """ Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body. Parameters ---------- id : str FinetunedModel ID. name : str FinetunedModel name (e.g. `foobar`). settings : Settings FinetunedModel settings such as dataset, hyperparameters... request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- UpdateFinetunedModelResponse A successful response. Examples -------- from cohere import Client from cohere.finetuning.finetuning import BaseModel, Settings client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.finetuning.update_finetuned_model( id="id", name="name", settings=Settings( base_model=BaseModel( base_type="BASE_TYPE_UNSPECIFIED", ), dataset_id="dataset_id", ), ) """ _response = self._raw_client.update_finetuned_model( id, name=name, settings=settings, request_options=request_options ) return _response.data def list_events( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListEventsResponse: """ Returns a list of events that occurred during the life-cycle of the fine-tuned model. The events are ordered by creation time, with the most recent event first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. 
page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. order_by : typing.Optional[str] Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default) request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListEventsResponse A successful response. Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.finetuning.list_events( finetuned_model_id="finetuned_model_id", page_size=1, page_token="page_token", order_by="order_by", ) """ _response = self._raw_client.list_events( finetuned_model_id, page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options, ) return _response.data def list_training_step_metrics( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListTrainingStepMetricsResponse: """ Returns a list of metrics measured during the training of a fine-tuned model. The metrics are ordered by step number, with the most recent step first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListTrainingStepMetricsResponse A successful response. 
Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.finetuning.list_training_step_metrics( finetuned_model_id="finetuned_model_id", page_size=1, page_token="page_token", ) """ _response = self._raw_client.list_training_step_metrics( finetuned_model_id, page_size=page_size, page_token=page_token, request_options=request_options ) return _response.data class AsyncFinetuningClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._raw_client = AsyncRawFinetuningClient(client_wrapper=client_wrapper) @property def with_raw_response(self) -> AsyncRawFinetuningClient: """ Retrieves a raw implementation of this client that returns raw responses. Returns ------- AsyncRawFinetuningClient """ return self._raw_client async def list_finetuned_models( self, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListFinetunedModelsResponse: """ Returns a list of fine-tuned models that the user has access to. Parameters ---------- page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. order_by : typing.Optional[str] Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default) request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListFinetunedModelsResponse A successful response. 
Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.finetuning.list_finetuned_models( page_size=1, page_token="page_token", order_by="order_by", ) asyncio.run(main()) """ _response = await self._raw_client.list_finetuned_models( page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options ) return _response.data async def create_finetuned_model( self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None ) -> CreateFinetunedModelResponse: """ Creates a new fine-tuned model. The model will be trained on the dataset specified in the request body. The training process may take some time, and the model will be available once the training is complete. Parameters ---------- request : FinetunedModel request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- CreateFinetunedModelResponse A successful response. Examples -------- import asyncio from cohere import AsyncClient from cohere.finetuning.finetuning import BaseModel, FinetunedModel, Settings client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.finetuning.create_finetuned_model( request=FinetunedModel( name="name", settings=Settings( base_model=BaseModel( base_type="BASE_TYPE_UNSPECIFIED", ), dataset_id="dataset_id", ), ), ) asyncio.run(main()) """ _response = await self._raw_client.create_finetuned_model(request=request, request_options=request_options) return _response.data async def get_finetuned_model( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> GetFinetunedModelResponse: """ Retrieve a fine-tuned model by its ID. Parameters ---------- id : str The fine-tuned model ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- GetFinetunedModelResponse A successful response. Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.finetuning.get_finetuned_model( id="id", ) asyncio.run(main()) """ _response = await self._raw_client.get_finetuned_model(id, request_options=request_options) return _response.data async def delete_finetuned_model( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> DeleteFinetunedModelResponse: """ Deletes a fine-tuned model. The model will be removed from the system and will no longer be available for use. This operation is irreversible. Parameters ---------- id : str The fine-tuned model ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- DeleteFinetunedModelResponse A successful response. Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.finetuning.delete_finetuned_model( id="id", ) asyncio.run(main()) """ _response = await self._raw_client.delete_finetuned_model(id, request_options=request_options) return _response.data async def update_finetuned_model( self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None ) -> UpdateFinetunedModelResponse: """ Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body. Parameters ---------- id : str FinetunedModel ID. name : str FinetunedModel name (e.g. `foobar`). settings : Settings FinetunedModel settings such as dataset, hyperparameters... request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- UpdateFinetunedModelResponse A successful response. 
Examples -------- import asyncio from cohere import AsyncClient from cohere.finetuning.finetuning import BaseModel, Settings client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.finetuning.update_finetuned_model( id="id", name="name", settings=Settings( base_model=BaseModel( base_type="BASE_TYPE_UNSPECIFIED", ), dataset_id="dataset_id", ), ) asyncio.run(main()) """ _response = await self._raw_client.update_finetuned_model( id, name=name, settings=settings, request_options=request_options ) return _response.data async def list_events( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListEventsResponse: """ Returns a list of events that occurred during the life-cycle of the fine-tuned model. The events are ordered by creation time, with the most recent event first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. order_by : typing.Optional[str] Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default) request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListEventsResponse A successful response. 
Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.finetuning.list_events( finetuned_model_id="finetuned_model_id", page_size=1, page_token="page_token", order_by="order_by", ) asyncio.run(main()) """ _response = await self._raw_client.list_events( finetuned_model_id, page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options, ) return _response.data async def list_training_step_metrics( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListTrainingStepMetricsResponse: """ Returns a list of metrics measured during the training of a fine-tuned model. The metrics are ordered by step number, with the most recent step first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListTrainingStepMetricsResponse A successful response. 
Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.finetuning.list_training_step_metrics( finetuned_model_id="finetuned_model_id", page_size=1, page_token="page_token", ) asyncio.run(main()) """ _response = await self._raw_client.list_training_step_metrics( finetuned_model_id, page_size=page_size, page_token=page_token, request_options=request_options ) return _response.data ================================================ FILE: src/cohere/finetuning/finetuning/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. # isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .types import ( BaseModel, BaseType, CreateFinetunedModelResponse, DeleteFinetunedModelResponse, Event, FinetunedModel, GetFinetunedModelResponse, Hyperparameters, ListEventsResponse, ListFinetunedModelsResponse, ListTrainingStepMetricsResponse, LoraTargetModules, Settings, Status, Strategy, TrainingStepMetrics, UpdateFinetunedModelResponse, WandbConfig, ) _dynamic_imports: typing.Dict[str, str] = { "BaseModel": ".types", "BaseType": ".types", "CreateFinetunedModelResponse": ".types", "DeleteFinetunedModelResponse": ".types", "Event": ".types", "FinetunedModel": ".types", "GetFinetunedModelResponse": ".types", "Hyperparameters": ".types", "ListEventsResponse": ".types", "ListFinetunedModelsResponse": ".types", "ListTrainingStepMetricsResponse": ".types", "LoraTargetModules": ".types", "Settings": ".types", "Status": ".types", "Strategy": ".types", "TrainingStepMetrics": ".types", "UpdateFinetunedModelResponse": ".types", "WandbConfig": ".types", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: 
module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ "BaseModel", "BaseType", "CreateFinetunedModelResponse", "DeleteFinetunedModelResponse", "Event", "FinetunedModel", "GetFinetunedModelResponse", "Hyperparameters", "ListEventsResponse", "ListFinetunedModelsResponse", "ListTrainingStepMetricsResponse", "LoraTargetModules", "Settings", "Status", "Strategy", "TrainingStepMetrics", "UpdateFinetunedModelResponse", "WandbConfig", ] ================================================ FILE: src/cohere/finetuning/finetuning/types/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# isort: skip_file

import typing
from importlib import import_module

# Imported eagerly only for static type checkers; resolved lazily at runtime
# through the module-level __getattr__ below (PEP 562).
if typing.TYPE_CHECKING:
    from .base_model import BaseModel
    from .base_type import BaseType
    from .create_finetuned_model_response import CreateFinetunedModelResponse
    from .delete_finetuned_model_response import DeleteFinetunedModelResponse
    from .event import Event
    from .finetuned_model import FinetunedModel
    from .get_finetuned_model_response import GetFinetunedModelResponse
    from .hyperparameters import Hyperparameters
    from .list_events_response import ListEventsResponse
    from .list_finetuned_models_response import ListFinetunedModelsResponse
    from .list_training_step_metrics_response import ListTrainingStepMetricsResponse
    from .lora_target_modules import LoraTargetModules
    from .settings import Settings
    from .status import Status
    from .strategy import Strategy
    from .training_step_metrics import TrainingStepMetrics
    from .update_finetuned_model_response import UpdateFinetunedModelResponse
    from .wandb_config import WandbConfig

# Maps each public name to the relative module that defines it.
_dynamic_imports: typing.Dict[str, str] = {
    "BaseModel": ".base_model",
    "BaseType": ".base_type",
    "CreateFinetunedModelResponse": ".create_finetuned_model_response",
    "DeleteFinetunedModelResponse": ".delete_finetuned_model_response",
    "Event": ".event",
    "FinetunedModel": ".finetuned_model",
    "GetFinetunedModelResponse": ".get_finetuned_model_response",
    "Hyperparameters": ".hyperparameters",
    "ListEventsResponse": ".list_events_response",
    "ListFinetunedModelsResponse": ".list_finetuned_models_response",
    "ListTrainingStepMetricsResponse": ".list_training_step_metrics_response",
    "LoraTargetModules": ".lora_target_modules",
    "Settings": ".settings",
    "Status": ".status",
    "Strategy": ".strategy",
    "TrainingStepMetrics": ".training_step_metrics",
    "UpdateFinetunedModelResponse": ".update_finetuned_model_response",
    "WandbConfig": ".wandb_config",
}


def __getattr__(attr_name: str) -> typing.Any:
    # Lazily import the module that defines `attr_name` on first access (PEP 562).
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        if module_name == f".{attr_name}":
            # The requested attribute is itself a submodule.
            return module
        else:
            return getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e


def __dir__():
    # Advertise the lazily-importable names to dir()/autocomplete.
    lazy_attrs = list(_dynamic_imports.keys())
    return sorted(lazy_attrs)


__all__ = [
    "BaseModel",
    "BaseType",
    "CreateFinetunedModelResponse",
    "DeleteFinetunedModelResponse",
    "Event",
    "FinetunedModel",
    "GetFinetunedModelResponse",
    "Hyperparameters",
    "ListEventsResponse",
    "ListFinetunedModelsResponse",
    "ListTrainingStepMetricsResponse",
    "LoraTargetModules",
    "Settings",
    "Status",
    "Strategy",
    "TrainingStepMetrics",
    "UpdateFinetunedModelResponse",
    "WandbConfig",
]


================================================
FILE: src/cohere/finetuning/finetuning/types/base_model.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel


class BaseModel(UncheckedBaseModel):
    """
    The base model used for fine-tuning.
    """

    name: typing.Optional[str] = pydantic.Field(default=None)
    """
    The name of the base model.
    """

    version: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. The version of the base model.
    """

    base_type: BaseType = pydantic.Field()
    """
    The type of the base model.
    """

    strategy: typing.Optional[Strategy] = pydantic.Field(default=None)
    """
    Deprecated: The fine-tuning strategy.
    """

    # Allow unknown fields so newer server responses do not fail validation.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/finetuning/finetuning/types/base_type.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Open union: known literals plus typing.Any so unrecognized server values
# still deserialize instead of raising.
BaseType = typing.Union[
    typing.Literal[
        "BASE_TYPE_UNSPECIFIED",
        "BASE_TYPE_GENERATIVE",
        "BASE_TYPE_CLASSIFICATION",
        "BASE_TYPE_RERANK",
        "BASE_TYPE_CHAT",
    ],
    typing.Any,
]


================================================
FILE: src/cohere/finetuning/finetuning/types/create_finetuned_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .finetuned_model import FinetunedModel


class CreateFinetunedModelResponse(UncheckedBaseModel):
    """
    Response to request to create a fine-tuned model.
    """

    finetuned_model: typing.Optional[FinetunedModel] = pydantic.Field(default=None)
    """
    Information about the fine-tuned model.
    """

    # Allow unknown fields so newer server responses do not fail validation.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/finetuning/finetuning/types/delete_finetuned_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Alias rather than a model: the delete endpoint returns an arbitrary JSON object.
DeleteFinetunedModelResponse = typing.Dict[str, typing.Any]
"""
Response to request to delete a fine-tuned model.
""" ================================================ FILE: src/cohere/finetuning/finetuning/types/event.py ================================================ # This file was auto-generated by Fern from our API Definition. import datetime as dt import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .status import Status class Event(UncheckedBaseModel): """ A change in status of a fine-tuned model. """ user_id: typing.Optional[str] = pydantic.Field(default=None) """ ID of the user who initiated the event. Empty if initiated by the system. """ status: typing.Optional[Status] = pydantic.Field(default=None) """ Status of the fine-tuned model. """ created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None) """ Timestamp when the event happened. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/finetuned_model.py ================================================ # This file was auto-generated by Fern from our API Definition. import datetime as dt import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .settings import Settings from .status import Status class FinetunedModel(UncheckedBaseModel): """ This resource represents a fine-tuned model. """ id: typing.Optional[str] = pydantic.Field(default=None) """ read-only. FinetunedModel ID. """ name: str = pydantic.Field() """ FinetunedModel name (e.g. `foobar`). """ creator_id: typing.Optional[str] = pydantic.Field(default=None) """ read-only. User ID of the creator. """ organization_id: typing.Optional[str] = pydantic.Field(default=None) """ read-only. Organization ID. 
""" settings: Settings = pydantic.Field() """ FinetunedModel settings such as dataset, hyperparameters... """ status: typing.Optional[Status] = pydantic.Field(default=None) """ read-only. Current stage in the life-cycle of the fine-tuned model. """ created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None) """ read-only. Creation timestamp. """ updated_at: typing.Optional[dt.datetime] = pydantic.Field(default=None) """ read-only. Latest update timestamp. """ completed_at: typing.Optional[dt.datetime] = pydantic.Field(default=None) """ read-only. Timestamp for the completed fine-tuning. """ last_used: typing.Optional[dt.datetime] = pydantic.Field(default=None) """ read-only. Deprecated: Timestamp for the latest request to this fine-tuned model. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/get_finetuned_model_response.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .finetuned_model import FinetunedModel class GetFinetunedModelResponse(UncheckedBaseModel): """ Response to a request to get a fine-tuned model. """ finetuned_model: typing.Optional[FinetunedModel] = pydantic.Field(default=None) """ Information about the fine-tuned model. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/hyperparameters.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .lora_target_modules import LoraTargetModules class Hyperparameters(UncheckedBaseModel): """ The fine-tuning hyperparameters. """ early_stopping_patience: typing.Optional[int] = pydantic.Field(default=None) """ Stops training if the loss metric does not improve beyond the value of `early_stopping_threshold` after this many times of evaluation. """ early_stopping_threshold: typing.Optional[float] = pydantic.Field(default=None) """ How much the loss must improve to prevent early stopping. """ train_batch_size: typing.Optional[int] = pydantic.Field(default=None) """ The batch size is the number of training examples included in a single training pass. """ train_epochs: typing.Optional[int] = pydantic.Field(default=None) """ The number of epochs to train for. """ learning_rate: typing.Optional[float] = pydantic.Field(default=None) """ The learning rate to be used during training. """ lora_alpha: typing.Optional[int] = pydantic.Field(default=None) """ Controls the scaling factor for LoRA updates. Higher values make the updates more impactful. """ lora_rank: typing.Optional[int] = pydantic.Field(default=None) """ Specifies the rank for low-rank matrices. Lower ranks reduce parameters but may limit model flexibility. """ lora_target_modules: typing.Optional[LoraTargetModules] = pydantic.Field(default=None) """ The combination of LoRA modules to target. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/list_events_response.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .event import Event class ListEventsResponse(UncheckedBaseModel): """ Response to a request to list events of a fine-tuned model. """ events: typing.Optional[typing.List[Event]] = pydantic.Field(default=None) """ List of events for the fine-tuned model. """ next_page_token: typing.Optional[str] = pydantic.Field(default=None) """ Pagination token to retrieve the next page of results. If the value is "", it means no further results for the request. """ total_size: typing.Optional[int] = pydantic.Field(default=None) """ Total count of results. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/list_finetuned_models_response.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .finetuned_model import FinetunedModel class ListFinetunedModelsResponse(UncheckedBaseModel): """ Response to a request to list fine-tuned models. 
""" finetuned_models: typing.Optional[typing.List[FinetunedModel]] = pydantic.Field(default=None) """ List of fine-tuned models matching the request. """ next_page_token: typing.Optional[str] = pydantic.Field(default=None) """ Pagination token to retrieve the next page of results. If the value is "", it means no further results for the request. """ total_size: typing.Optional[int] = pydantic.Field(default=None) """ Total count of results. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/list_training_step_metrics_response.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .training_step_metrics import TrainingStepMetrics class ListTrainingStepMetricsResponse(UncheckedBaseModel): """ Response to a request to list training-step metrics of a fine-tuned model. """ step_metrics: typing.Optional[typing.List[TrainingStepMetrics]] = pydantic.Field(default=None) """ The metrics for each step the evaluation was run on. """ next_page_token: typing.Optional[str] = pydantic.Field(default=None) """ Pagination token to retrieve the next page of results. If the value is "", it means no further results for the request. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/lora_target_modules.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing LoraTargetModules = typing.Union[ typing.Literal[ "LORA_TARGET_MODULES_UNSPECIFIED", "LORA_TARGET_MODULES_QV", "LORA_TARGET_MODULES_QKVO", "LORA_TARGET_MODULES_QKVO_FFN", ], typing.Any, ] ================================================ FILE: src/cohere/finetuning/finetuning/types/settings.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .base_model import BaseModel from .hyperparameters import Hyperparameters from .wandb_config import WandbConfig class Settings(UncheckedBaseModel): """ The configuration used for fine-tuning. """ base_model: BaseModel = pydantic.Field() """ The base model to fine-tune. """ dataset_id: str = pydantic.Field() """ The data used for training and evaluating the fine-tuned model. """ hyperparameters: typing.Optional[Hyperparameters] = pydantic.Field(default=None) """ Fine-tuning hyper-parameters. """ multi_label: typing.Optional[bool] = pydantic.Field(default=None) """ read-only. Whether the model is single-label or multi-label (only for classification). """ wandb: typing.Optional[WandbConfig] = pydantic.Field(default=None) """ The Weights & Biases configuration (Chat fine-tuning only). 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/status.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing Status = typing.Union[ typing.Literal[ "STATUS_UNSPECIFIED", "STATUS_FINETUNING", "STATUS_DEPLOYING_API", "STATUS_READY", "STATUS_FAILED", "STATUS_DELETED", "STATUS_TEMPORARILY_OFFLINE", "STATUS_PAUSED", "STATUS_QUEUED", ], typing.Any, ] ================================================ FILE: src/cohere/finetuning/finetuning/types/strategy.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing Strategy = typing.Union[typing.Literal["STRATEGY_UNSPECIFIED", "STRATEGY_VANILLA", "STRATEGY_TFEW"], typing.Any] ================================================ FILE: src/cohere/finetuning/finetuning/types/training_step_metrics.py ================================================ # This file was auto-generated by Fern from our API Definition. import datetime as dt import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel class TrainingStepMetrics(UncheckedBaseModel): """ The evaluation metrics at a given step of the training of a fine-tuned model. """ created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None) """ Creation timestamp. """ step_number: typing.Optional[int] = pydantic.Field(default=None) """ Step number. """ metrics: typing.Optional[typing.Dict[str, float]] = pydantic.Field(default=None) """ Map of names and values for each evaluation metrics. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/update_finetuned_model_response.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel from .finetuned_model import FinetunedModel class UpdateFinetunedModelResponse(UncheckedBaseModel): """ Response to a request to update a fine-tuned model. """ finetuned_model: typing.Optional[FinetunedModel] = pydantic.Field(default=None) """ Information about the fine-tuned model. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/finetuning/types/wandb_config.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ....core.pydantic_utilities import IS_PYDANTIC_V2 from ....core.unchecked_base_model import UncheckedBaseModel class WandbConfig(UncheckedBaseModel): """ The Weights & Biases configuration. """ project: str = pydantic.Field() """ The WandB project name to be used during training. """ api_key: str = pydantic.Field() """ The WandB API key to be used during training. """ entity: typing.Optional[str] = pydantic.Field(default=None) """ The WandB entity name to be used during training. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/finetuning/raw_client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from json.decoder import JSONDecodeError from ..core.api_error import ApiError from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.http_response import AsyncHttpResponse, HttpResponse from ..core.jsonable_encoder import jsonable_encoder from ..core.parse_error import ParsingError from ..core.request_options import RequestOptions from ..core.serialization import convert_and_respect_annotation_metadata from ..core.unchecked_base_model import construct_type from ..errors.bad_request_error import BadRequestError from ..errors.forbidden_error import ForbiddenError from ..errors.internal_server_error import InternalServerError from ..errors.not_found_error import NotFoundError from ..errors.service_unavailable_error import ServiceUnavailableError from ..errors.unauthorized_error import UnauthorizedError from .finetuning.types.create_finetuned_model_response import CreateFinetunedModelResponse from .finetuning.types.delete_finetuned_model_response import DeleteFinetunedModelResponse from .finetuning.types.finetuned_model import FinetunedModel from .finetuning.types.get_finetuned_model_response import GetFinetunedModelResponse from .finetuning.types.list_events_response import ListEventsResponse from .finetuning.types.list_finetuned_models_response import ListFinetunedModelsResponse from .finetuning.types.list_training_step_metrics_response import ListTrainingStepMetricsResponse from .finetuning.types.settings import Settings from .finetuning.types.update_finetuned_model_response import UpdateFinetunedModelResponse from 
pydantic import ValidationError

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)


class RawFinetuningClient:
    # Synchronous low-level client: each method returns the full HttpResponse
    # wrapper (headers + parsed body) instead of just the parsed body.
    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    def list_finetuned_models(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[ListFinetunedModelsResponse]:
        """
        Returns a list of fine-tuned models that the user has access to.

        Parameters
        ----------
        page_size : typing.Optional[int]
            Maximum number of results to be returned by the server. If 0, defaults to 50.

        page_token : typing.Optional[str]
            Request a specific page of the list results.

        order_by : typing.Optional[str]
            Comma separated list of fields. For example: "created_at,name". The default
            sorting order is ascending. To specify descending order for a field, append
            " desc" to the field name. For example: "created_at desc,name".

            Supported sorting fields:
              - created_at (default)

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ListFinetunedModelsResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/finetuning/finetuned-models",
            method="GET",
            params={
                "page_size": page_size,
                "page_token": page_token,
                "order_by": order_by,
            },
            request_options=request_options,
        )
        # Map the HTTP status to a parsed response or a typed error; JSON/validation
        # failures fall through to the generic ApiError/ParsingError handlers below.
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    ListFinetunedModelsResponse,
                    construct_type(
                        type_=ListFinetunedModelsResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(),
                cause=e
            )
        # Unhandled status code: surface the raw JSON body.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def create_finetuned_model(
        self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[CreateFinetunedModelResponse]:
        """
        Creates a new fine-tuned model. The model will be trained on the dataset
        specified in the request body. The training process may take some time, and
        the model will be available once the training is complete.

        Parameters
        ----------
        request : FinetunedModel

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[CreateFinetunedModelResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/finetuning/finetuned-models",
            method="POST",
            json=convert_and_respect_annotation_metadata(object_=request, annotation=FinetunedModel, direction="write"),
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        # Map the HTTP status to a parsed response or a typed error (same pattern
        # as list_finetuned_models above).
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    CreateFinetunedModelResponse,
                    construct_type(
                        type_=CreateFinetunedModelResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        # Unhandled status code: surface the raw JSON body.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def get_finetuned_model(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[GetFinetunedModelResponse]:
        """
        Retrieve a fine-tuned model by its ID.

        Parameters
        ----------
        id : str
            The fine-tuned model ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[GetFinetunedModelResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
            method="GET",
            request_options=request_options,
        )
        # Map the HTTP status to a parsed response or a typed error (same pattern
        # as list_finetuned_models above).
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    GetFinetunedModelResponse,
                    construct_type(
                        type_=GetFinetunedModelResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        # Unhandled status code: surface the raw JSON body.
        raise ApiError(status_code=_response.status_code,
                       headers=dict(_response.headers), body=_response_json)

    def delete_finetuned_model(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[DeleteFinetunedModelResponse]:
        """
        Deletes a fine-tuned model. The model will be removed from the system and
        will no longer be available for use. This operation is irreversible.

        Parameters
        ----------
        id : str
            The fine-tuned model ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[DeleteFinetunedModelResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
            method="DELETE",
            request_options=request_options,
        )
        # Map the HTTP status to a parsed response or a typed error (same pattern
        # as list_finetuned_models above).
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    DeleteFinetunedModelResponse,
                    construct_type(
                        type_=DeleteFinetunedModelResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def update_finetuned_model( self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[UpdateFinetunedModelResponse]: """ Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body. Parameters ---------- id : str FinetunedModel ID. name : str FinetunedModel name (e.g. `foobar`). settings : Settings FinetunedModel settings such as dataset, hyperparameters... request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[UpdateFinetunedModelResponse] A successful response. 
""" _response = self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}", method="PATCH", json={ "name": name, "settings": convert_and_respect_annotation_metadata( object_=settings, annotation=Settings, direction="write" ), }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( UpdateFinetunedModelResponse, construct_type( type_=UpdateFinetunedModelResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) 
except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def list_events( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[ListEventsResponse]: """ Returns a list of events that occurred during the life-cycle of the fine-tuned model. The events are ordered by creation time, with the most recent event first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. order_by : typing.Optional[str] Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default) request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[ListEventsResponse] A successful response. 
""" _response = self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/events", method="GET", params={ "page_size": page_size, "page_token": page_token, "order_by": order_by, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListEventsResponse, construct_type( type_=ListEventsResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), 
body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def list_training_step_metrics( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[ListTrainingStepMetricsResponse]: """ Returns a list of metrics measured during the training of a fine-tuned model. The metrics are ordered by step number, with the most recent step first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[ListTrainingStepMetricsResponse] A successful response. 
""" _response = self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/training-step-metrics", method="GET", params={ "page_size": page_size, "page_token": page_token, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListTrainingStepMetricsResponse, construct_type( type_=ListTrainingStepMetricsResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, 
headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawFinetuningClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper async def list_finetuned_models( self, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[ListFinetunedModelsResponse]: """ Returns a list of fine-tuned models that the user has access to. Parameters ---------- page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. order_by : typing.Optional[str] Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default) request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[ListFinetunedModelsResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( "v1/finetuning/finetuned-models", method="GET", params={ "page_size": page_size, "page_token": page_token, "order_by": order_by, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListFinetunedModelsResponse, construct_type( type_=ListFinetunedModelsResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), 
body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def create_finetuned_model( self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[CreateFinetunedModelResponse]: """ Creates a new fine-tuned model. The model will be trained on the dataset specified in the request body. The training process may take some time, and the model will be available once the training is complete. Parameters ---------- request : FinetunedModel request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[CreateFinetunedModelResponse] A successful response. """ _response = await self._client_wrapper.httpx_client.request( "v1/finetuning/finetuned-models", method="POST", json=convert_and_respect_annotation_metadata(object_=request, annotation=FinetunedModel, direction="write"), headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CreateFinetunedModelResponse, construct_type( type_=CreateFinetunedModelResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def get_finetuned_model( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[GetFinetunedModelResponse]: """ Retrieve a fine-tuned model by its ID. Parameters ---------- id : str The fine-tuned model ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[GetFinetunedModelResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( GetFinetunedModelResponse, construct_type( type_=GetFinetunedModelResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise 
ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def delete_finetuned_model( self, id: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[DeleteFinetunedModelResponse]: """ Deletes a fine-tuned model. The model will be removed from the system and will no longer be available for use. This operation is irreversible. Parameters ---------- id : str The fine-tuned model ID. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[DeleteFinetunedModelResponse] A successful response. """ _response = await self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}", method="DELETE", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DeleteFinetunedModelResponse, construct_type( type_=DeleteFinetunedModelResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, 
construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def update_finetuned_model( self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[UpdateFinetunedModelResponse]: """ Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body. Parameters ---------- id : str FinetunedModel ID. name : str FinetunedModel name (e.g. `foobar`). settings : Settings FinetunedModel settings such as dataset, hyperparameters... request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[UpdateFinetunedModelResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}", method="PATCH", json={ "name": name, "settings": convert_and_respect_annotation_metadata( object_=settings, annotation=Settings, direction="write" ), }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( UpdateFinetunedModelResponse, construct_type( type_=UpdateFinetunedModelResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), 
body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def list_events( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, order_by: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[ListEventsResponse]: """ Returns a list of events that occurred during the life-cycle of the fine-tuned model. The events are ordered by creation time, with the most recent event first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. order_by : typing.Optional[str] Comma separated list of fields. For example: "created_at,name". The default sorting order is ascending. To specify descending order for a field, append " desc" to the field name. For example: "created_at desc,name". Supported sorting fields: - created_at (default) request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[ListEventsResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/events", method="GET", params={ "page_size": page_size, "page_token": page_token, "order_by": order_by, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListEventsResponse, construct_type( type_=ListEventsResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, 
headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def list_training_step_metrics( self, finetuned_model_id: str, *, page_size: typing.Optional[int] = None, page_token: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[ListTrainingStepMetricsResponse]: """ Returns a list of metrics measured during the training of a fine-tuned model. The metrics are ordered by step number, with the most recent step first. The list can be paginated using `page_size` and `page_token` parameters. Parameters ---------- finetuned_model_id : str The parent fine-tuned model ID. page_size : typing.Optional[int] Maximum number of results to be returned by the server. If 0, defaults to 50. page_token : typing.Optional[str] Request a specific page of the list results. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[ListTrainingStepMetricsResponse] A successful response. 
""" _response = await self._client_wrapper.httpx_client.request( f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/training-step-metrics", method="GET", params={ "page_size": page_size, "page_token": page_token, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListTrainingStepMetricsResponse, construct_type( type_=ListTrainingStepMetricsResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, 
headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) ================================================ FILE: src/cohere/manually_maintained/__init__.py ================================================ # This module ensures overrides are applied early in the import process # Import overrides to trigger backwards compatibility patches from .. import overrides # noqa: F401 ================================================ FILE: src/cohere/manually_maintained/cache.py ================================================ import typing import time class CacheMixin: # A simple in-memory cache with TTL (thread safe). This is used to cache tokenizers at the moment. _cache: typing.Dict[str, typing.Tuple[typing.Optional[float], typing.Any]] = dict() def _cache_get(self, key: str) -> typing.Any: val = self._cache.get(key) if val is None: return None expiry_timestamp, value = val if expiry_timestamp is None or expiry_timestamp > time.time(): return value del self._cache[key] # remove expired cache entry def _cache_set(self, key: str, value: typing.Any, ttl: int = 60 * 60) -> None: expiry_timestamp = None if ttl is not None: expiry_timestamp = time.time() + ttl self._cache[key] = (expiry_timestamp, value) ================================================ FILE: src/cohere/manually_maintained/cohere_aws/__init__.py ================================================ from .client import Client from .error import CohereError from .mode import Mode ================================================ FILE: src/cohere/manually_maintained/cohere_aws/chat.py ================================================ from .response import CohereObject from .error import CohereError from .mode import Mode from typing import List, Optional, Generator, Dict, Any, Union from enum import Enum import json # Tools class ToolParameterDefinitionsValue(CohereObject, dict): def __init__( self, type: str, 
description: str, required: Optional[bool] = None, **kwargs, ) -> None: super().__init__(**kwargs) self.__dict__ = self self.type = type self.description = description if required is not None: self.required = required class Tool(CohereObject, dict): def __init__( self, name: str, description: str, parameter_definitions: Optional[Dict[str, ToolParameterDefinitionsValue]] = None, **kwargs, ) -> None: super().__init__(**kwargs) self.__dict__ = self self.name = name self.description = description if parameter_definitions is not None: self.parameter_definitions = parameter_definitions class ToolCall(CohereObject, dict): def __init__( self, name: str, parameters: Dict[str, Any], generation_id: str, **kwargs, ) -> None: super().__init__(**kwargs) self.__dict__ = self self.name = name self.parameters = parameters self.generation_id = generation_id @classmethod def from_dict(cls, tool_call_res: Dict[str, Any]) -> "ToolCall": return cls( name=tool_call_res.get("name"), parameters=tool_call_res.get("parameters"), generation_id=tool_call_res.get("generation_id"), ) @classmethod def from_list(cls, tool_calls_res: Optional[List[Dict[str, Any]]]) -> Optional[List["ToolCall"]]: if tool_calls_res is None or not isinstance(tool_calls_res, list): return None return [ToolCall.from_dict(tc) for tc in tool_calls_res] # Chat class Chat(CohereObject): def __init__( self, response_id: str, generation_id: str, text: str, chat_history: Optional[List[Dict[str, Any]]] = None, preamble: Optional[str] = None, finish_reason: Optional[str] = None, token_count: Optional[Dict[str, int]] = None, tool_calls: Optional[List[ToolCall]] = None, citations: Optional[List[Dict[str, Any]]] = None, documents: Optional[List[Dict[str, Any]]] = None, search_results: Optional[List[Dict[str, Any]]] = None, search_queries: Optional[List[Dict[str, Any]]] = None, is_search_required: Optional[bool] = None, ) -> None: self.response_id = response_id self.generation_id = generation_id self.text = text self.chat_history = 
chat_history self.preamble = preamble self.finish_reason = finish_reason self.token_count = token_count self.tool_calls = tool_calls self.citations = citations self.documents = documents self.search_results = search_results self.search_queries = search_queries self.is_search_required = is_search_required @classmethod def from_dict(cls, response: Dict[str, Any]) -> "Chat": return cls( response_id=response["response_id"], generation_id=response.get("generation_id"), # optional text=response.get("text"), chat_history=response.get("chat_history"), # optional preamble=response.get("preamble"), # optional token_count=response.get("token_count"), is_search_required=response.get("is_search_required"), # optional citations=response.get("citations"), # optional documents=response.get("documents"), # optional search_results=response.get("search_results"), # optional search_queries=response.get("search_queries"), # optional finish_reason=response.get("finish_reason"), tool_calls=ToolCall.from_list(response.get("tool_calls")), # optional ) # ---------------| # Steaming event | # ---------------| class StreamEvent(str, Enum): STREAM_START = "stream-start" SEARCH_QUERIES_GENERATION = "search-queries-generation" SEARCH_RESULTS = "search-results" TEXT_GENERATION = "text-generation" TOOL_CALLS_GENERATION = "tool-calls-generation" CITATION_GENERATION = "citation-generation" STREAM_END = "stream-end" class StreamResponse(CohereObject): def __init__( self, is_finished: bool, event_type: Union[StreamEvent, str], index: Optional[int], **kwargs, ) -> None: super().__init__(**kwargs) self.is_finished = is_finished self.index = index self.event_type = event_type class StreamStart(StreamResponse): def __init__( self, generation_id: str, conversation_id: Optional[str], **kwargs, ) -> None: super().__init__(**kwargs) self.generation_id = generation_id self.conversation_id = conversation_id class StreamTextGeneration(StreamResponse): def __init__( self, text: str, **kwargs, ) -> None: 
super().__init__(**kwargs) self.text = text class StreamCitationGeneration(StreamResponse): def __init__( self, citations: Optional[List[Dict[str, Any]]], **kwargs, ) -> None: super().__init__(**kwargs) self.citations = citations class StreamQueryGeneration(StreamResponse): def __init__( self, search_queries: Optional[List[Dict[str, Any]]], **kwargs, ) -> None: super().__init__(**kwargs) self.search_queries = search_queries class StreamSearchResults(StreamResponse): def __init__( self, search_results: Optional[List[Dict[str, Any]]], documents: Optional[List[Dict[str, Any]]], **kwargs, ) -> None: super().__init__(**kwargs) self.search_results = search_results self.documents = documents class StreamEnd(StreamResponse): def __init__( self, finish_reason: str, **kwargs, ) -> None: super().__init__(**kwargs) self.finish_reason = finish_reason class ChatToolCallsGenerationEvent(StreamResponse): def __init__( self, tool_calls: Optional[List[ToolCall]], **kwargs, ) -> None: super().__init__(**kwargs) self.tool_calls = tool_calls class StreamingChat(CohereObject): def __init__(self, stream_response, mode): self.stream_response = stream_response self.text = None self.response_id = None self.generation_id = None self.preamble = None self.prompt = None self.chat_history = None self.finish_reason = None self.token_count = None self.is_search_required = None self.citations = None self.documents = None self.search_results = None self.search_queries = None self.tool_calls = None self.bytes = bytearray() if mode == Mode.SAGEMAKER: self.payload_key = "PayloadPart" self.bytes_key = "Bytes" elif mode == Mode.BEDROCK: self.payload_key = "chunk" self.bytes_key = "bytes" def _make_response_item(self, index, streaming_item) -> Any: event_type = streaming_item.get("event_type") if event_type == StreamEvent.STREAM_START: self.conversation_id = streaming_item.get("conversation_id") self.generation_id = streaming_item.get("generation_id") return StreamStart( 
conversation_id=self.conversation_id, generation_id=self.generation_id, is_finished=False, event_type=event_type, index=index, ) elif event_type == StreamEvent.SEARCH_QUERIES_GENERATION: search_queries = streaming_item.get("search_queries") return StreamQueryGeneration( search_queries=search_queries, is_finished=False, event_type=event_type, index=index ) elif event_type == StreamEvent.SEARCH_RESULTS: search_results = streaming_item.get("search_results") documents = streaming_item.get("documents") return StreamSearchResults( search_results=search_results, documents=documents, is_finished=False, event_type=event_type, index=index, ) elif event_type == StreamEvent.TEXT_GENERATION: text = streaming_item.get("text") return StreamTextGeneration(text=text, is_finished=False, event_type=event_type, index=index) elif event_type == StreamEvent.CITATION_GENERATION: citations = streaming_item.get("citations") return StreamCitationGeneration(citations=citations, is_finished=False, event_type=event_type, index=index) elif event_type == StreamEvent.TOOL_CALLS_GENERATION: tool_calls = ToolCall.from_list(streaming_item.get("tool_calls")) return ChatToolCallsGenerationEvent( tool_calls=tool_calls, is_finished=False, event_type=event_type, index=index ) elif event_type == StreamEvent.STREAM_END: response = streaming_item.get("response") finish_reason = streaming_item.get("finish_reason") self.finish_reason = finish_reason if response is None: return None self.response_id = response.get("response_id") self.conversation_id = response.get("conversation_id") self.text = response.get("text") self.generation_id = response.get("generation_id") self.preamble = response.get("preamble") self.prompt = response.get("prompt") self.chat_history = response.get("chat_history") self.token_count = response.get("token_count") self.is_search_required = response.get("is_search_required") # optional self.citations = response.get("citations") # optional self.documents = response.get("documents") # 
optional self.search_results = response.get("search_results") # optional self.search_queries = response.get("search_queries") # optional self.tool_calls = ToolCall.from_list(response.get("tool_calls")) # optional return StreamEnd(finish_reason=finish_reason, is_finished=True, event_type=event_type, index=index) return None def __iter__(self) -> Generator[StreamResponse, None, None]: index = 0 for payload in self.stream_response: self.bytes.extend(payload[self.payload_key][self.bytes_key]) try: item = self._make_response_item(index, json.loads(self.bytes)) except json.decoder.JSONDecodeError: # payload contained only a partion JSON object continue self.bytes = bytearray() if item is not None: index += 1 yield item ================================================ FILE: src/cohere/manually_maintained/cohere_aws/classification.py ================================================ from .response import CohereObject from typing import Any, Dict, Iterator, List, Literal, Union Prediction = Union[str, int, List[str], List[int]] ClassificationDict = Dict[Literal["prediction", "confidence", "text"], Any] class Classification(CohereObject): def __init__(self, classification: Union[Prediction, ClassificationDict]) -> None: # Prediction is the old format (version 1 of classification-finetuning) # ClassificationDict is the new format (version 2 of classification-finetuning). 
# It also contains the original text and the labels' confidence scores of the prediction self.classification = classification def is_multilabel(self) -> bool: if isinstance(self.classification, list): return True elif isinstance(self.classification, (int, str)): return False return isinstance(self.classification["prediction"], list) @property def prediction(self) -> Prediction: if isinstance(self.classification, (list, int, str)): return self.classification return self.classification["prediction"] @property def confidence(self) -> List[float]: if isinstance(self.classification, (list, int, str)): raise ValueError( "Confidence scores are not available for version prior to 2.0 of Cohere Classification Finetuning AWS package" ) return self.classification["confidence"] @property def text(self) -> str: if isinstance(self.classification, (list, int, str)): raise ValueError( "Original text is not available for version prior to 2.0 of Cohere Classification Finetuning AWS package" ) return self.classification["text"] class Classifications(CohereObject): def __init__(self, classifications: List[Classification]) -> None: self.classifications = classifications if len(self.classifications) > 0: assert all( [c.is_multilabel() == self.is_multilabel() for c in self.classifications] ), "All classifications must be of the same type (single-label or multi-label)" def __iter__(self) -> Iterator: return iter(self.classifications) def __len__(self) -> int: return len(self.classifications) def is_multilabel(self) -> bool: return len(self.classifications) > 0 and self.classifications[0].is_multilabel() ================================================ FILE: src/cohere/manually_maintained/cohere_aws/client.py ================================================ import json import os import tarfile import tempfile import time from typing import Any, Dict, List, Optional, Union from .classification import Classification, Classifications from .embeddings import Embeddings from .error import 
CohereError from .generation import Generations, StreamingGenerations from .chat import Chat, StreamingChat from .rerank import Reranking from .summary import Summary from .mode import Mode import typing from ..lazy_aws_deps import lazy_boto3, lazy_botocore, lazy_sagemaker class Client: def __init__( self, aws_region: typing.Optional[str] = None, mode: Mode = Mode.SAGEMAKER, ): """ By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with `aws configure set region us-west-2` or override it with `region_name` parameter. """ self.mode = mode if os.environ.get('AWS_DEFAULT_REGION') is None: os.environ['AWS_DEFAULT_REGION'] = aws_region if self.mode == Mode.SAGEMAKER: self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region) self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region) self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client) elif self.mode == Mode.BEDROCK: self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region) self._service_client = lazy_boto3().client("bedrock", region_name=aws_region) self._sess = None self._endpoint_name = None def _require_sagemaker(self) -> None: if self.mode != Mode.SAGEMAKER: raise CohereError("This method is only supported in SageMaker mode.") def _does_endpoint_exist(self, endpoint_name: str) -> bool: try: self._service_client.describe_endpoint(EndpointName=endpoint_name) except lazy_botocore().ClientError: return False return True def connect_to_endpoint(self, endpoint_name: str) -> None: """Connects to an existing SageMaker endpoint. Args: endpoint_name (str): The name of the endpoint. Raises: CohereError: Connection to the endpoint failed. 
""" self._require_sagemaker() if not self._does_endpoint_exist(endpoint_name): raise CohereError(f"Endpoint {endpoint_name} does not exist.") self._endpoint_name = endpoint_name def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str: """ Compress an S3 folder which contains one or several fine-tuned models to a tar file. If the S3 folder contains only one fine-tuned model, it simply returns the path to that model. If the S3 folder contains several fine-tuned models, it download all models, aggregates them into a single tar.gz file. Args: s3_models_dir (str): S3 URI pointing to a folder Returns: str: S3 URI pointing to the `models.tar.gz` file """ s3_models_dir = s3_models_dir.rstrip("/") + "/" # Links of all fine-tuned models in s3_models_dir. Their format should be .tar.gz s3_tar_models = [ s3_path for s3_path in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) if ( s3_path.endswith(".tar.gz") # only .tar.gz files and (s3_path.split("/")[-1] != "models.tar.gz") # exclude the .tar.gz file we are creating and (s3_path.rsplit("/", 1)[0] == s3_models_dir[:-1]) # only files at the root of s3_models_dir ) ] if len(s3_tar_models) == 0: raise CohereError(f"No fine-tuned models found in {s3_models_dir}") elif len(s3_tar_models) == 1: print(f"Found one fine-tuned model: {s3_tar_models[0]}") return s3_tar_models[0] # More than one fine-tuned model found, need to aggregate them into a single .tar.gz file with tempfile.TemporaryDirectory() as tmpdir: local_tar_models_dir = os.path.join(tmpdir, "tar") local_models_dir = os.path.join(tmpdir, "models") # Download and extract all fine-tuned models for s3_tar_model in s3_tar_models: print(f"Adding fine-tuned model: {s3_tar_model}") lazy_sagemaker().s3.S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess) with tarfile.open(os.path.join(local_tar_models_dir, s3_tar_model.split("/")[-1])) as tar: tar.extractall(local_models_dir) # Compress local_models_dir to 
a tar.gz file model_tar = os.path.join(tmpdir, "models.tar.gz") with tarfile.open(model_tar, "w:gz") as tar: tar.add(local_models_dir, arcname=".") # Upload the new tarfile containing all models to s3 # Very important to remove the trailing slash from s3_models_dir otherwise it just doesn't upload model_tar_s3 = lazy_sagemaker().s3.S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess) # sanity check assert s3_models_dir + "models.tar.gz" in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) return model_tar_s3 def create_endpoint( self, arn: str, endpoint_name: str, s3_models_dir: Optional[str] = None, instance_type: str = "ml.g4dn.xlarge", n_instances: int = 1, recreate: bool = False, role: Optional[str] = None, ) -> None: """Creates and deploys a SageMaker endpoint. Args: arn (str): The product ARN. Refers to a ready-to-use model (model package) or a fine-tuned model (algorithm). endpoint_name (str): The name of the endpoint. s3_models_dir (str, optional): S3 URI pointing to the folder containing fine-tuned models. Defaults to None. instance_type (str, optional): The EC2 instance type to deploy the endpoint to. Defaults to "ml.g4dn.xlarge". n_instances (int, optional): Number of endpoint instances. Defaults to 1. recreate (bool, optional): Force re-creation of endpoint if it already exists. Defaults to False. role (str, optional): The IAM role to use for the endpoint. If not provided, sagemaker.get_execution_role() will be used to get the role. This should work when one uses the client inside SageMaker. If this errors out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker. 
""" self._require_sagemaker() # First, check if endpoint already exists if self._does_endpoint_exist(endpoint_name): if recreate: self.connect_to_endpoint(endpoint_name) self.delete_endpoint() else: raise CohereError(f"Endpoint {endpoint_name} already exists and recreate={recreate}.") kwargs = {} model_data = None validation_params = dict() useBoto = False if s3_models_dir is not None: # If s3_models_dir is given, we assume to have custom fine-tuned models -> Algorithm kwargs["algorithm_arn"] = arn model_data = self._s3_models_dir_to_tarfile(s3_models_dir) else: # If no s3_models_dir is given, we assume to use a pre-trained model -> ModelPackage kwargs["model_package_arn"] = arn # For now only non-finetuned models can use these timeouts validation_params = dict( model_data_download_timeout=2400, container_startup_health_check_timeout=2400 ) useBoto = True # Out of precaution, check if there is an endpoint config and delete it if that's the case # Otherwise it might block deployment try: self._service_client.delete_endpoint_config(EndpointConfigName=endpoint_name) except lazy_botocore().ClientError: pass try: self._service_client.delete_model(ModelName=endpoint_name) except lazy_botocore().ClientError: pass if role is None: if useBoto: accountID = lazy_sagemaker().account_id() role = f"arn:aws:iam::{accountID}:role/ServiceRoleSagemaker" else: try: role = lazy_sagemaker().get_execution_role() except ValueError: print("Using default role: 'ServiceRoleSagemaker'.") role = "ServiceRoleSagemaker" # deploy fine-tuned model using sagemaker SDK if s3_models_dir is not None: model = lazy_sagemaker().ModelPackage( role=role, model_data=model_data, sagemaker_session=self._sess, # makes sure the right region is used **kwargs ) try: model.deploy( n_instances, instance_type, endpoint_name=endpoint_name, **validation_params ) except lazy_botocore().ParamValidationError: # For at least some versions of python 3.6, SageMaker SDK does not support the validation_params 
model.deploy(n_instances, instance_type, endpoint_name=endpoint_name) else: # deploy pre-trained model using boto to add InferenceAmiVersion self._service_client.create_model( ModelName=endpoint_name, ExecutionRoleArn=role, EnableNetworkIsolation=True, PrimaryContainer={ 'ModelPackageName': arn, }, ) self._service_client.create_endpoint_config( EndpointConfigName=endpoint_name, ProductionVariants=[ { 'VariantName': 'AllTraffic', 'ModelName': endpoint_name, 'InstanceType': instance_type, 'InitialInstanceCount': n_instances, 'InferenceAmiVersion': 'al2-ami-sagemaker-inference-gpu-2' }, ], ) self._service_client.create_endpoint( EndpointName=endpoint_name, EndpointConfigName=endpoint_name, ) waiter = self._service_client.get_waiter('endpoint_in_service') try: print(f"Waiting for endpoint {endpoint_name} to be in service...") waiter.wait( EndpointName=endpoint_name, WaiterConfig={ 'Delay': 30, 'MaxAttempts': 80 } ) except Exception as e: raise CohereError(f"Failed to create endpoint: {e}") self.connect_to_endpoint(endpoint_name) def chat( self, message: str, stream: Optional[bool] = False, preamble: Optional[str] = None, chat_history: Optional[List[Dict[str, Any]]] = None, # should only be passed for stacked finetune deployment model: Optional[str] = None, # should only be passed for Bedrock mode; ignored otherwise model_id: Optional[str] = None, temperature: Optional[float] = None, p: Optional[float] = None, k: Optional[float] = None, max_tokens: Optional[int] = None, search_queries_only: Optional[bool] = None, documents: Optional[List[Dict[str, Any]]] = None, prompt_truncation: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_results: Optional[List[Dict[str, Any]]] = None, raw_prompting: Optional[bool] = False, return_prompt: Optional[bool] = False, variant: Optional[str] = None, ) -> Union[Chat, StreamingChat]: """Returns a Chat object with the query reply. Args: message (str): The message to send to the chatbot. 
stream (bool): Return streaming tokens. preamble (str): (Optional) A string to override the preamble. chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state. model (str): (Optional) The model to use for generating the response. Should only be passed for stacked finetune deployment. model_id (str): (Optional) The model to use for generating the response. Should only be passed for Bedrock mode; ignored otherwise. temperature (float): (Optional) The temperature to use for the response. The higher the temperature, the more random the response. p (float): (Optional) The nucleus sampling probability. k (float): (Optional) The top-k sampling probability. max_tokens (int): (Optional) The max tokens generated for the next reply. search_queries_only (bool): (Optional) When true, the response will only contain a list of generated `search_queries`, no reply from the model to the user's message will be generated. documents (List[Dict[str, str]]): (Optional) Documents to use to generate grounded response with citations. Example: documents=[ { "id": "national_geographic_everest", "title": "Height of Mount Everest", "snippet": "The height of Mount Everest is 29,035 feet", "url": "https://education.nationalgeographic.org/resource/mount-everest/", }, { "id": "national_geographic_mariana", "title": "Depth of the Mariana Trench", "snippet": "The depth of the Mariana Trench is 36,070 feet", "url": "https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth", }, ], prompt_truncation (str) (Optional): Defaults to `OFF`. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. 
During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be raised. Returns: a Chat object if stream=False, or a StreamingChat object if stream=True Examples: A simple chat message: >>> res = co.chat(message="Hey! How are you doing today?") >>> print(res.text) Streaming chat: >>> res = co.chat( >>> message="Hey! How are you doing today?", >>> stream=True) >>> for token in res: >>> print(token) Stateless chat with chat history: >>> res = co.chat( >>> chat_history=[ >>> {'role': 'User', message': 'Hey! How are you doing today?'}, >>> {'role': 'Chatbot', message': 'I am doing great! How can I help you?'}, >>> message="Tell me a joke!", >>> ]) >>> print(res.text) Chat message with documents to use to generate the response: >>> res = co.chat( >>> "How deep in the Mariana Trench", >>> documents=[ >>> { >>> "id": "national_geographic_everest", >>> "title": "Height of Mount Everest", >>> "snippet": "The height of Mount Everest is 29,035 feet", >>> "url": "https://education.nationalgeographic.org/resource/mount-everest/", >>> }, >>> { >>> "id": "national_geographic_mariana", >>> "title": "Depth of the Mariana Trench", >>> "snippet": "The depth of the Mariana Trench is 36,070 feet", >>> "url": "https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth", >>> }, >>> ]) >>> print(res.text) >>> print(res.citations) >>> print(res.documents) Generate search queries for fetching documents to use in chat: >>> res = co.chat( >>> "What is the height of Mount Everest?", >>> search_queries_only=True) >>> if res.is_search_required: >>> print(res.search_queries) """ if self.mode == Mode.SAGEMAKER and self._endpoint_name is None: raise CohereError("No endpoint connected. 
" "Run connect_to_endpoint() first.") json_params = { "model": model, "message": message, "chat_history": chat_history, "preamble": preamble, "temperature": temperature, "max_tokens": max_tokens, "stream": stream, "p": p, "k": k, "tools": tools, "tool_results": tool_results, "search_queries_only": search_queries_only, "documents": documents, "raw_prompting": raw_prompting, "return_prompt": return_prompt, "prompt_truncation": prompt_truncation } for key, value in list(json_params.items()): if value is None: del json_params[key] if self.mode == Mode.SAGEMAKER: return self._sagemaker_chat(json_params, variant) elif self.mode == Mode.BEDROCK: return self._bedrock_chat(json_params, model_id) else: raise CohereError("Unsupported mode") def _sagemaker_chat(self, json_params: Dict[str, Any], variant: str) : json_body = json.dumps(json_params) params = { 'EndpointName': self._endpoint_name, 'ContentType': 'application/json', 'Body': json_body, } if variant: params['TargetVariant'] = variant try: if json_params['stream']: result = self._client.invoke_endpoint_with_response_stream( **params) return StreamingChat(result['Body'], self.mode) else: result = self._client.invoke_endpoint(**params) return Chat.from_dict(json.loads(result['Body'].read().decode())) except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? # ValidationError, e.g. 
when variant is bad raise CohereError(str(e)) def _bedrock_chat(self, json_params: Dict[str, Any], model_id: str) : if not model_id: raise CohereError("must supply model_id arg when calling bedrock") if json_params['stream']: stream = json_params['stream'] else: stream = False # Bedrock does not expect the stream key to be present in the body, use invoke_model_with_response_stream to indicate stream mode del json_params['stream'] json_body = json.dumps(json_params) params = { 'body': json_body, 'modelId': model_id, } try: if stream: result = self._client.invoke_model_with_response_stream( **params) return StreamingChat(result['body'], self.mode) else: result = self._client.invoke_model(**params) return Chat.from_dict( json.loads(result['body'].read().decode())) except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? # ValidationError, e.g. when variant is bad raise CohereError(str(e)) def generate( self, prompt: str, # should only be passed for stacked finetune deployment model: Optional[str] = None, # should only be passed for Bedrock mode; ignored otherwise model_id: Optional[str] = None, # requires DB with presets # preset: str = None, num_generations: int = 1, max_tokens: int = 400, temperature: float = 1.0, k: int = 0, p: float = 0.75, stop_sequences: Optional[List[str]] = None, return_likelihoods: Optional[str] = None, truncate: Optional[str] = None, variant: Optional[str] = None, stream: Optional[bool] = True, ) -> Union[Generations, StreamingGenerations]: if self.mode == Mode.SAGEMAKER and self._endpoint_name is None: raise CohereError("No endpoint connected. 
" "Run connect_to_endpoint() first.") json_params = { 'model': model, 'prompt': prompt, 'max_tokens': max_tokens, 'temperature': temperature, 'k': k, 'p': p, 'stop_sequences': stop_sequences, 'return_likelihoods': return_likelihoods, 'truncate': truncate, 'stream': stream, } for key, value in list(json_params.items()): if value is None: del json_params[key] if self.mode == Mode.SAGEMAKER: # TODO: Bedrock should support this param too json_params['num_generations'] = num_generations return self._sagemaker_generations(json_params, variant) elif self.mode == Mode.BEDROCK: return self._bedrock_generations(json_params, model_id) else: raise CohereError("Unsupported mode") def _sagemaker_generations(self, json_params: Dict[str, Any], variant: str) : json_body = json.dumps(json_params) params = { 'EndpointName': self._endpoint_name, 'ContentType': 'application/json', 'Body': json_body, } if variant: params['TargetVariant'] = variant try: if json_params['stream']: result = self._client.invoke_endpoint_with_response_stream( **params) return StreamingGenerations(result['Body'], self.mode) else: result = self._client.invoke_endpoint(**params) return Generations( json.loads(result['Body'].read().decode())['generations']) except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? # ValidationError, e.g. 
            # when variant is bad
            raise CohereError(str(e))

    def _bedrock_generations(self, json_params: Dict[str, Any], model_id: str):
        """Invoke a Bedrock model for text generation.

        Dispatches to the streaming or non-streaming Bedrock invoke API
        depending on ``json_params['stream']``.

        NOTE(review): reads ``json_params['stream']`` unconditionally —
        a KeyError results if the caller never set it; presumably the
        (unseen) caller always populates it — confirm.
        """
        if not model_id:
            raise CohereError("must supply model_id arg when calling bedrock")
        json_body = json.dumps(json_params)
        params = {
            'body': json_body,
            'modelId': model_id,
        }
        try:
            if json_params['stream']:
                result = self._client.invoke_model_with_response_stream(**params)
                return StreamingGenerations(result['body'], self.mode)
            else:
                result = self._client.invoke_model(**params)
                return Generations(
                    json.loads(result['body'].read().decode())['generations'])
        except lazy_botocore().EndpointConnectionError as e:
            raise CohereError(str(e))
        except Exception as e:
            # TODO should be client error - distinct type from CohereError?
            # ValidationError, e.g. when variant is bad
            raise CohereError(str(e))

    def embed(
        self,
        texts: List[str],
        truncate: Optional[str] = None,
        variant: Optional[str] = None,
        input_type: Optional[str] = None,
        model_id: Optional[str] = None,
        output_dimension: Optional[int] = None,
        embedding_types: Optional[List[str]] = None,
    ) -> Union[Embeddings, Dict[str, List]]:
        """Embed a batch of texts via SageMaker or Bedrock.

        Args:
            texts: The texts to embed.
            truncate: Optional truncation strategy forwarded to the model.
            variant: SageMaker production variant to target (SageMaker mode only).
            input_type: Embedding input type forwarded to the model.
            model_id: Bedrock model id (Bedrock mode only).
            output_dimension: Requested embedding dimensionality.
            embedding_types: Requested embedding representations.

        Returns:
            An ``Embeddings`` wrapper, or the raw per-type dict when the
            endpoint returns embeddings keyed by type.

        Raises:
            CohereError: On transport errors or when the client mode is
                neither SAGEMAKER nor BEDROCK.
        """
        json_params = {
            'texts': texts,
            'truncate': truncate,
            "input_type": input_type,
            "output_dimension": output_dimension,
            "embedding_types": embedding_types,
        }
        # Drop unset parameters so the endpoint only sees explicit values.
        for key, value in list(json_params.items()):
            if value is None:
                del json_params[key]

        if self.mode == Mode.SAGEMAKER:
            return self._sagemaker_embed(json_params, variant)
        elif self.mode == Mode.BEDROCK:
            return self._bedrock_embed(json_params, model_id)
        else:
            raise CohereError("Unsupported mode")

    def _sagemaker_embed(self, json_params: Dict[str, Any], variant: Optional[str]):
        """Invoke the connected SageMaker endpoint for embeddings."""
        if self._endpoint_name is None:
            raise CohereError("No endpoint connected. "
                              "Run connect_to_endpoint() first.")
        json_body = json.dumps(json_params)
        params = {
            'EndpointName': self._endpoint_name,
            'ContentType': 'application/json',
            'Body': json_body,
        }
        if variant:
            params['TargetVariant'] = variant
        try:
            result = self._client.invoke_endpoint(**params)
            response = json.loads(result['Body'].read().decode())
        except lazy_botocore().EndpointConnectionError as e:
            raise CohereError(str(e))
        except Exception as e:
            # TODO should be client error - distinct type from CohereError?
            # ValidationError, e.g. when variant is bad
            raise CohereError(str(e))
        embeddings = response['embeddings']
        # embeddings_by_type responses come back as a dict keyed by type;
        # pass those through untouched.
        if isinstance(embeddings, dict):
            return embeddings
        return Embeddings(embeddings)

    def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
        """Invoke a Bedrock model for embeddings."""
        if not model_id:
            raise CohereError("must supply model_id arg when calling bedrock")
        json_body = json.dumps(json_params)
        params = {
            'body': json_body,
            'modelId': model_id,
        }
        try:
            result = self._client.invoke_model(**params)
            response = json.loads(result['body'].read().decode())
        except lazy_botocore().EndpointConnectionError as e:
            raise CohereError(str(e))
        except Exception as e:
            # TODO should be client error - distinct type from CohereError?
            # ValidationError, e.g. when variant is bad
            raise CohereError(str(e))
        embeddings = response['embeddings']
        if isinstance(embeddings, dict):
            return embeddings
        return Embeddings(embeddings)

    def rerank(self,
               query: str,
               documents: Union[List[str], List[Dict[str, Any]]],
               top_n: Optional[int] = None,
               variant: Optional[str] = None,
               max_chunks_per_doc: Optional[int] = None,
               rank_fields: Optional[List[str]] = None) -> Reranking:
        """Returns an ordered list of documents ordered by their relevance to the provided query

        Args:
            query (str): The search query
            documents (list[str], list[dict]): The documents to rerank
            top_n (int): (optional) The number of results to return, defaults to return all results
            variant (str): (optional) The SageMaker production variant to target
            max_chunks_per_doc (int): (optional) The maximum number of chunks derived from a document
            rank_fields (list[str]): (optional) The fields used for reranking. This parameter is only supported for rerank v3 models
        """
        if self._endpoint_name is None:
            raise CohereError("No endpoint connected. "
                              "Run connect_to_endpoint() first.")
        # Normalise every document to dict form so the endpoint payload is uniform.
        parsed_docs = []
        for doc in documents:
            if isinstance(doc, str):
                parsed_docs.append({'text': doc})
            elif isinstance(doc, dict):
                parsed_docs.append(doc)
            else:
                raise CohereError(
                    message='invalid format for documents, must be a list of strings or dicts')

        json_params = {
            "query": query,
            "documents": parsed_docs,
            "top_n": top_n,
            "return_documents": False,
            "max_chunks_per_doc": max_chunks_per_doc,
            "rank_fields": rank_fields
        }
        json_body = json.dumps(json_params)

        params = {
            'EndpointName': self._endpoint_name,
            'ContentType': 'application/json',
            'Body': json_body,
        }
        if variant is not None:
            params['TargetVariant'] = variant

        try:
            result = self._client.invoke_endpoint(**params)
            response = json.loads(result['Body'].read().decode())
            reranking = Reranking(response)
            # return_documents=False above, so re-attach the original docs locally.
            for rank in reranking.results:
                rank.document = parsed_docs[rank.index]
        except lazy_botocore().EndpointConnectionError as e:
            raise CohereError(str(e))
        except Exception as e:
            # TODO should be client error - distinct type from CohereError?
            # ValidationError, e.g. when variant is bad
            raise CohereError(str(e))
        return reranking

    def classify(self, input: List[str], name: str) -> Classifications:
        """Classify a batch of texts against the model named ``name``.

        Args:
            input (list[str]): The texts to classify.
            name (str): The model id to classify with.
        """
        if self._endpoint_name is None:
            raise CohereError("No endpoint connected. "
                              "Run connect_to_endpoint() first.")

        json_params = {"texts": input, "model_id": name}
        json_body = json.dumps(json_params)
        params = {
            "EndpointName": self._endpoint_name,
            "ContentType": "application/json",
            "Body": json_body,
        }

        try:
            result = self._client.invoke_endpoint(**params)
            response = json.loads(result["Body"].read().decode())
        except lazy_botocore().EndpointConnectionError as e:
            raise CohereError(str(e))
        except Exception as e:
            # TODO should be client error - distinct type from CohereError?
            # ValidationError, e.g. when variant is bad
            raise CohereError(str(e))

        return Classifications([Classification(classification) for classification in response])

    def create_finetune(
        self,
        name: str,
        train_data: str,
        s3_models_dir: str,
        arn: Optional[str] = None,
        eval_data: Optional[str] = None,
        instance_type: str = "ml.g4dn.xlarge",
        training_parameters: Dict[str, Any] = {},  # Optional, training algorithm specific hyper-parameters
        role: Optional[str] = None,
        base_model_id: Optional[str] = None,
    ) -> Optional[str]:
        """Creates a fine-tuning job and returns an optional finetune job ID.

        Args:
            name (str): The name to give to the fine-tuned model.
            train_data (str): An S3 path pointing to the training data.
            s3_models_dir (str): An S3 path pointing to the directory where the fine-tuned model will be saved.
            arn (str, optional): The product ARN of the fine-tuning package. Required in Sagemaker mode and ignored otherwise
            eval_data (str, optional): An S3 path pointing to the eval data. Defaults to None.
            instance_type (str, optional): The EC2 instance type to use for training. Defaults to "ml.g4dn.xlarge".
            training_parameters (Dict[str, Any], optional): Additional training parameters. Defaults to {}.
            role (str, optional): The IAM role to use for the endpoint.
                In Bedrock this mode is required and is used to access s3 input and output data.
                If not provided in sagemaker, sagemaker.get_execution_role() will be used to get the role.
                This should work when one uses the client inside SageMaker. If this errors out,
                the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
            base_model_id (str, optional): The ID of the Bedrock base model to finetune with.
                Required in Bedrock mode and ignored otherwise.
        """
        # NOTE(review): mutable default argument — training_parameters is
        # mutated below via .update(), so the injected "name" key leaks
        # across calls that rely on the default. Should be `= None` with
        # an in-body copy; left unchanged here.
        assert name != "model", "name cannot be 'model'"

        if self.mode == Mode.BEDROCK:
            return self._bedrock_create_finetune(name=name, train_data=train_data,
                                                 s3_models_dir=s3_models_dir,
                                                 base_model=base_model_id,
                                                 eval_data=eval_data,
                                                 training_parameters=training_parameters,
                                                 role=role)

        s3_models_dir = s3_models_dir.rstrip("/") + "/"

        if role is None:
            try:
                role = lazy_sagemaker().get_execution_role()
            except ValueError:
                print("Using default role: 'ServiceRoleSagemaker'.")
                role = "ServiceRoleSagemaker"

        training_parameters.update({"name": name})
        estimator = lazy_sagemaker().algorithm.AlgorithmEstimator(
            algorithm_arn=arn,
            role=role,
            instance_count=1,
            instance_type=instance_type,
            sagemaker_session=self._sess,
            output_path=s3_models_dir,
            hyperparameters=training_parameters,
        )

        inputs = {}
        if not train_data.startswith("s3:"):
            raise ValueError("train_data must point to an S3 location.")
        inputs["training"] = train_data
        if eval_data is not None:
            if not eval_data.startswith("s3:"):
                raise ValueError("eval_data must point to an S3 location.")
            inputs["evaluation"] = eval_data
        estimator.fit(inputs=inputs)
        job_name = estimator.latest_training_job.name

        current_filepath = f"{s3_models_dir}{job_name}/output/model.tar.gz"

        s3_resource = lazy_boto3().resource("s3")

        # Copy new model to root of output_model_dir
        bucket, old_key = lazy_sagemaker().s3.parse_s3_url(current_filepath)
        _, new_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_models_dir}{name}.tar.gz")
        s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})

        # Delete old dir
        bucket, old_short_key = lazy_sagemaker().s3.parse_s3_url(s3_models_dir + job_name)
        s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete()

    def export_finetune(
        self,
        name: str,
        s3_checkpoint_dir: str,
        s3_output_dir: str,
        arn: str,
        instance_type: str = "ml.p4de.24xlarge",
        role: Optional[str] = None,
    ) -> None:
        """Export the merged weights to the TensorRT-LLM inference engine.

        Args:
            name (str): The name used while writing the exported model to the output directory.
            s3_checkpoint_dir (str): An S3 path pointing to the directory of the model checkpoint (merged weights).
            s3_output_dir (str): An S3 path pointing to the directory where the TensorRT-LLM engine will be saved.
            arn (str): The product ARN of the bring your own finetuning algorithm.
            instance_type (str, optional): The EC2 instance type to use for export. Defaults to "ml.p4de.24xlarge".
            role (str, optional): The IAM role to use for export.
                If not provided, sagemaker.get_execution_role() will be used to get the role.
                This should work when one uses the client inside SageMaker. If this errors out,
                the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
        """
        self._require_sagemaker()

        if name == "model":
            raise ValueError("name cannot be 'model'")

        s3_output_dir = s3_output_dir.rstrip("/") + "/"

        if role is None:
            try:
                role = lazy_sagemaker().get_execution_role()
            except ValueError:
                print("Using default role: 'ServiceRoleSagemaker'.")
                role = "ServiceRoleSagemaker"

        export_parameters = {"name": name}

        estimator = lazy_sagemaker().algorithm.AlgorithmEstimator(
            algorithm_arn=arn,
            role=role,
            instance_count=1,
            instance_type=instance_type,
            sagemaker_session=self._sess,
            output_path=s3_output_dir,
            hyperparameters=export_parameters,
        )

        if not s3_checkpoint_dir.startswith("s3:"):
            raise ValueError("s3_checkpoint_dir must point to an S3 location.")
        inputs = {"checkpoint": s3_checkpoint_dir}
        estimator.fit(inputs=inputs)
        job_name = estimator.latest_training_job.name

        current_filepath = f"{s3_output_dir}{job_name}/output/model.tar.gz"

        s3_resource = lazy_boto3().resource("s3")

        # Copy the exported TensorRT-LLM engine to the root of s3_output_dir
        bucket, old_key = lazy_sagemaker().s3.parse_s3_url(current_filepath)
        _, new_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{name}.tar.gz")
        s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})

        # Delete the old S3 directory
        bucket, old_short_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{job_name}")
        s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete()

    def wait_for_finetune_job(self, job_id: str, timeout: int = 2 * 60 * 60) -> str:
        """Waits for a finetune job to complete and returns a model arn if complete.

        Throws an exception if timeout occurs or if job does not complete successfully

        Args:
            job_id (str): The arn of the model customization job
            timeout(int, optional): Timeout in seconds
        """
        end = time.time() + timeout
        # Poll Bedrock every 10s until the job reaches a terminal status
        # or the deadline passes.
        while True:
            customization_job = self._service_client.get_model_customization_job(jobIdentifier=job_id)
            job_status = customization_job["status"]
            if job_status in ["Completed", "Failed", "Stopped"]:
                break
            if time.time() > end:
                raise CohereError("could not complete finetune within timeout")
            time.sleep(10)

        if job_status != "Completed":
            raise CohereError(f"finetune did not finish successfuly, ended with {job_status} status")
        return customization_job["outputModelArn"]

    def provision_throughput(
        self,
        model_id: str,
        name: str,
        model_units: int,
        commitment_duration: Optional[str] = None
    ) -> str:
        """Returns the provisioned model arn

        Args:
            model_id (str): The ID or ARN of the model to provision
            name (str): Name of the provisioned throughput model
            model_units (int): Number of units to provision
            commitment_duration (str, optional): Commitment duration, one of ("OneMonth", "SixMonths"), defaults to no commitment if unspecified
        """
        if self.mode != Mode.BEDROCK:
            raise ValueError("can only provision throughput in bedrock")
        kwargs = {}
        if commitment_duration:
            kwargs["commitmentDuration"] = commitment_duration

        response = self._service_client.create_provisioned_model_throughput(
            provisionedModelName=name,
            modelId=model_id,
            modelUnits=model_units,
            **kwargs
        )
        return response["provisionedModelArn"]

    def _bedrock_create_finetune(
        self,
        name: str,
        train_data: str,
        s3_models_dir: str,
        base_model: str,
        eval_data: Optional[str] = None,
        training_parameters: Dict[str, Any] = {},  # Optional, training algorithm specific hyper-parameters
        role: Optional[str] = None,
    ) -> None:
        """Create a Bedrock model-customization (fine-tuning) job.

        Returns the job ARN (NOTE(review): the ``-> None`` annotation is
        inconsistent with the ``return`` below — annotation only, left as-is).
        """
        if not name:
            raise ValueError("name must not be empty")
        if not role:
            raise ValueError("must provide a role ARN for bedrock finetuning "
                             "(https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-iam-role.html)")
        if not train_data.startswith("s3:"):
            raise ValueError("train_data must point to an S3 location.")
        if eval_data:
            if not eval_data.startswith("s3:"):
                raise ValueError("eval_data must point to an S3 location.")
            validationDataConfig = {
                "validators": [{
                    "s3Uri": eval_data
                }]
            }

        job_name = f"{name}-job"
        customization_job = self._service_client.create_model_customization_job(
            jobName=job_name,
            customModelName=name,
            roleArn=role,
            baseModelIdentifier=base_model,
            trainingDataConfig={"s3Uri": train_data},
            validationDataConfig=validationDataConfig,
            outputDataConfig={"s3Uri": s3_models_dir},
            hyperParameters=training_parameters
        )
        return customization_job["jobArn"]

    def summarize(
        self,
        text: str,
        length: Optional[str] = "auto",
        format_: Optional[str] = "auto",
        # Only summarize-xlarge is supported on Sagemaker
        # model: Optional[str] = "summarize-xlarge",
        extractiveness: Optional[str] = "auto",
        temperature: Optional[float] = 0.3,
        additional_command: Optional[str] = "",
        variant: Optional[str] = None
    ) -> Summary:
        """Summarize ``text`` via the connected SageMaker endpoint.

        Args:
            text (str): The text to summarize.
            length (str, optional): Requested summary length.
            format_ (str, optional): Requested summary format.
            extractiveness (str, optional): How extractive the summary should be.
            temperature (float, optional): Sampling temperature.
            additional_command (str, optional): Extra instruction for the model.
            variant (str, optional): SageMaker production variant to target.
        """
        self._require_sagemaker()
        if self._endpoint_name is None:
            raise CohereError("No endpoint connected. "
                              "Run connect_to_endpoint() first.")

        json_params = {
            'text': text,
            'length': length,
            'format': format_,
            'extractiveness': extractiveness,
            'temperature': temperature,
            'additional_command': additional_command,
        }
        # Drop unset parameters so the endpoint only sees explicit values.
        for key, value in list(json_params.items()):
            if value is None:
                del json_params[key]

        json_body = json.dumps(json_params)
        params = {
            'EndpointName': self._endpoint_name,
            'ContentType': 'application/json',
            'Body': json_body,
        }
        if variant is not None:
            params['TargetVariant'] = variant

        try:
            result = self._client.invoke_endpoint(**params)
            response = json.loads(result['Body'].read().decode())
            summary = Summary(response)
        except lazy_botocore().EndpointConnectionError as e:
            raise CohereError(str(e))
        except Exception as e:
            # TODO should be client error - distinct type from CohereError?
            # ValidationError, e.g. when variant is bad
            raise CohereError(str(e))
        return summary

    def delete_endpoint(self) -> None:
        """Delete the connected SageMaker endpoint and its endpoint config.

        Best-effort: each deletion is attempted independently and a missing
        resource is reported rather than raised.
        NOTE(review): bare ``except:`` also swallows KeyboardInterrupt /
        SystemExit — should be ``except Exception:``; left unchanged.
        """
        self._require_sagemaker()
        if self._endpoint_name is None:
            raise CohereError("No endpoint connected.")
        try:
            self._service_client.delete_endpoint(EndpointName=self._endpoint_name)
        except:
            print("Endpoint not found, skipping deletion.")

        try:
            self._service_client.delete_endpoint_config(EndpointConfigName=self._endpoint_name)
        except:
            print("Endpoint config not found, skipping deletion.")

    def close(self) -> None:
        """Close the underlying boto3 clients; re-raises if close() is unavailable."""
        try:
            self._client.close()
            self._service_client.close()
        except AttributeError:
            print("SageMaker client could not be closed. This might be because you are using an old version of SageMaker.")
            raise



================================================
FILE: src/cohere/manually_maintained/cohere_aws/embeddings.py
================================================
from .response import CohereObject
from typing import Iterator, List


class Embedding(CohereObject):
    """A single embedding vector; iterable and sized like a list of floats."""

    def __init__(self, embedding: List[float]) -> None:
        self.embedding = embedding

    def __iter__(self) -> Iterator:
        return iter(self.embedding)

    def __len__(self) -> int:
        return len(self.embedding)


class Embeddings(CohereObject):
    """A batch of Embedding objects; iterable and sized like a list."""

    def __init__(self, embeddings: List[Embedding]) -> None:
        self.embeddings = embeddings

    def __iter__(self) -> Iterator:
        return iter(self.embeddings)

    def __len__(self) -> int:
        return len(self.embeddings)



================================================
FILE: src/cohere/manually_maintained/cohere_aws/error.py
================================================
class CohereError(Exception):
    """Base exception for the cohere_aws package.

    Carries an optional message, HTTP status code, and response headers.
    """

    def __init__(
        self,
        message=None,
        http_status=None,
        headers=None,
    ) -> None:
        super(CohereError, self).__init__(message)
        self.message = message
        self.http_status = http_status
        self.headers = headers or {}

    def __str__(self) -> str:
        msg = self.message or ''
        return msg

    def __repr__(self) -> str:
        return '%s(message=%r, http_status=%r)' % (
            self.__class__.__name__,
            self.message,
            self.http_status,
        )
================================================
FILE: src/cohere/manually_maintained/cohere_aws/generation.py
================================================
from .response import CohereObject
from .mode import Mode
from typing import List, Optional, NamedTuple, Generator, Dict, Any
import json


class TokenLikelihood(CohereObject):
    """A generated token paired with its log-likelihood."""

    def __init__(self, token: str, likelihood: float) -> None:
        self.token = token
        self.likelihood = likelihood


class Generation(CohereObject):
    """One generated completion and its (optional) per-token likelihoods."""

    def __init__(self,
                 text: str,
                 token_likelihoods: List[TokenLikelihood]) -> None:
        self.text = text
        self.token_likelihoods = token_likelihoods


class Generations(CohereObject):
    """A batch of Generation objects that can be iterated exactly once."""

    def __init__(self, generations: List[Generation]) -> None:
        self.generations = generations
        # Single shared iterator: the object is exhausted after one pass.
        self.iterator = iter(generations)

    @classmethod
    def from_dict(cls, response: Dict[str, Any]) -> List[Generation]:
        # NOTE(review): actually returns a Generations instance (cls(...)),
        # not List[Generation] — the annotation is inaccurate; left as-is.
        generations: List[Generation] = []
        for gen in response['generations']:
            token_likelihoods = None
            if 'token_likelihoods' in gen:
                token_likelihoods = []
                for likelihoods in gen['token_likelihoods']:
                    if 'likelihood' in likelihoods:
                        token_likelihood = likelihoods['likelihood']
                    else:
                        token_likelihood = None
                    token_likelihoods.append(TokenLikelihood(
                        likelihoods['token'], token_likelihood))
            generations.append(Generation(gen['text'], token_likelihoods))
        return cls(generations)

    # NOTE(review): `iter`/`next` builtins are not valid return annotations
    # (should be Iterator / Generation); harmless at runtime, left as-is.
    def __iter__(self) -> iter:
        return self.iterator

    def __next__(self) -> next:
        return next(self.iterator)


# One incremental text chunk from a streaming generation.
StreamingText = NamedTuple("StreamingText", [("index", Optional[int]), ("text", str), ("is_finished", bool)])


class StreamingGenerations(CohereObject):
    """Adapts a SageMaker/Bedrock event stream into StreamingText chunks."""

    def __init__(self, stream, mode):
        self.stream = stream
        self.id = None
        self.generations = None
        self.finish_reason = None
        # Buffer for payload bytes; a chunk may contain a partial JSON object.
        self.bytes = bytearray()
        if mode == Mode.SAGEMAKER:
            self.payload_key = "PayloadPart"
            self.bytes_key = "Bytes"
        elif mode == Mode.BEDROCK:
            self.payload_key = "chunk"
            self.bytes_key = "bytes"
        else:
            # NOTE(review): CohereError is not imported in this module —
            # reaching this branch raises NameError, not CohereError.
            raise CohereError("Unsupported mode")

    def _make_response_item(self, streaming_item) -> Optional[StreamingText]:
        """Convert one decoded stream payload into a StreamingText, or None.

        On the final payload, records finish_reason / id / generations on
        self instead of yielding a chunk.
        """
        is_finished = streaming_item.get("is_finished")
        if not is_finished:
            index = streaming_item.get("index", 0)
            text = streaming_item.get("text")
            if text is None:
                return None
            return StreamingText(
                text=text, is_finished=is_finished, index=index)
        self.finish_reason = streaming_item.get("finish_reason")
        generation_response = streaming_item.get("response")
        if generation_response is None:
            return None
        self.id = generation_response.get("id")
        self.generations = Generations.from_dict(generation_response)
        return None

    def __iter__(self) -> Generator[StreamingText, None, None]:
        for payload in self.stream:
            self.bytes.extend(payload[self.payload_key][self.bytes_key])
            try:
                item = self._make_response_item(json.loads(self.bytes))
            except json.decoder.JSONDecodeError:
                # payload contained only a partial JSON object; keep buffering
                continue
            self.bytes = bytearray()
            if item is not None:
                yield item



================================================
FILE: src/cohere/manually_maintained/cohere_aws/mode.py
================================================
from enum import Enum


class Mode(Enum):
    # Which AWS service the client talks to.
    SAGEMAKER = 1
    BEDROCK = 2



================================================
FILE: src/cohere/manually_maintained/cohere_aws/rerank.py
================================================
from typing import Any, Dict, Iterator, List, NamedTuple, Optional
from .response import CohereObject

RerankDocument = NamedTuple("Document", [("text", str)])
RerankDocument.__doc__ = """
Returned by co.rerank,
dict which always contains text but can also contain arbitrary fields
"""


class RerankResult(CohereObject):
    """One reranked document with its original index and relevance score."""

    def __init__(self,
                 document: Optional[Dict[str, Any]] = None,
                 index: Optional[int] = None,
                 relevance_score: Optional[float] = None,
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.document = document
        self.index = index
        self.relevance_score = relevance_score

    def __repr__(self) -> str:
        # NOTE(review): the interpolation fields of these f-strings appear to
        # have been stripped by the extraction (score/index/text are computed
        # but unused) — reproduce as shown; confirm against upstream source.
        score = self.relevance_score
        index = self.index
        if self.document is None:
            return f"RerankResult"
        elif 'text' in self.document:
            text = self.document['text']
            return f"RerankResult"
        else:
            return f"RerankResult"


class Reranking(CohereObject):
    """The full rerank response: an indexable, iterable list of RerankResult."""

    def __init__(self,
                 response: Optional[Dict[str, Any]] = None,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        assert response is not None
        self.results = self._results(response)

    def _results(self, response: Dict[str, Any]) -> List[RerankResult]:
        # Build RerankResult objects, tolerating responses without documents.
        results = []
        for res in response['results']:
            if 'document' in res.keys():
                results.append(
                    RerankResult(res['document'], res['index'], res['relevance_score']))
            else:
                results.append(
                    RerankResult(index=res['index'], relevance_score=res['relevance_score']))
        return results

    def __str__(self) -> str:
        return str(self.results)

    def __repr__(self) -> str:
        return self.results.__repr__()

    def __iter__(self) -> Iterator:
        return iter(self.results)

    def __getitem__(self, index) -> RerankResult:
        return self.results[index]



================================================
FILE: src/cohere/manually_maintained/cohere_aws/response.py
================================================
class CohereObject():
    """Base class giving cohere_aws response objects a readable repr."""

    def __repr__(self) -> str:
        contents = ''
        # 'iterator' is internal one-shot state — keep it out of the repr.
        exclude_list = ['iterator']
        for k in self.__dict__.keys():
            if k not in exclude_list:
                contents += f'\t{k}: {self.__dict__[k]}\n'
        output = f'cohere.{type(self).__name__} {{\n{contents}}}'
        return output



================================================
FILE: src/cohere/manually_maintained/cohere_aws/summary.py
================================================
from .error import CohereError
from .response import CohereObject
from typing import Any, Dict, Optional


class Summary(CohereObject):
    """Wraps the 'summary' field of a summarize response."""

    def __init__(self, response: Optional[Dict[str, Any]] = None) -> None:
        assert response is not None
        if not response["summary"]:
            raise CohereError("Response lacks a summary")
        self.result = response["summary"]

    def __str__(self) -> str:
        return self.result



================================================
FILE: src/cohere/manually_maintained/lazy_aws_deps.py
================================================
# Lazy importers for the optional AWS extras, so the base package works
# without boto3/botocore/sagemaker installed.
warning = "AWS dependencies are not installed. Please install boto3, botocore, and sagemaker."


def lazy_sagemaker():
    """Import and return the sagemaker module, or raise a helpful ImportError."""
    try:
        import sagemaker as sage  # type: ignore
        return sage
    except ImportError:
        raise ImportError(warning)


def lazy_boto3():
    """Import and return the boto3 module, or raise a helpful ImportError."""
    try:
        import boto3  # type: ignore
        return boto3
    except ImportError:
        raise ImportError(warning)


def lazy_botocore():
    """Import and return the botocore module, or raise a helpful ImportError."""
    try:
        import botocore  # type: ignore
        return botocore
    except ImportError:
        raise ImportError(warning)



================================================
FILE: src/cohere/manually_maintained/lazy_oci_deps.py
================================================
"""Lazy loading for optional OCI SDK dependency."""

from typing import Any

OCI_INSTALLATION_MESSAGE = """
The OCI SDK is required to use OciClient or OciClientV2.
Install it with:
    pip install oci
Or with the optional dependency group:
    pip install cohere[oci]
"""


def lazy_oci() -> Any:
    """
    Lazily import the OCI SDK.

    Returns:
        The oci module

    Raises:
        ImportError: If the OCI SDK is not installed
    """
    try:
        import oci  # type: ignore[import-untyped, import-not-found]
        return oci
    except ImportError:
        raise ImportError(OCI_INSTALLATION_MESSAGE)



================================================
FILE: src/cohere/manually_maintained/streaming_embed.py
================================================
"""Utilities for streaming embed responses without loading all embeddings into memory."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator, List, Optional, Union


@dataclass
class StreamedEmbedding:
    """A single embedding yielded incrementally from embed_stream()."""

    index: int
    embedding: Union[List[float], List[int]]
    embedding_type: str
    text: Optional[str] = None


def extract_embeddings_from_response(
    response_data: dict,
    batch_texts: List[str],
    global_offset: int = 0,
) -> Iterator[StreamedEmbedding]:
    """
    Extract individual embeddings from a Cohere embed response dict.

    Works for both V1 (embeddings_floats / embeddings_by_type) and V2
    response formats.

    Args:
        response_data: Parsed JSON response from embed endpoint
        batch_texts: The texts that were embedded in this batch
        global_offset: Starting index for this batch within the full dataset

    Yields:
        StreamedEmbedding objects
    """
    response_type = response_data.get("response_type", "")

    if response_type == "embeddings_floats":
        embeddings = response_data.get("embeddings", [])
        for i, embedding in enumerate(embeddings):
            yield StreamedEmbedding(
                index=global_offset + i,
                embedding=embedding,
                embedding_type="float",
                text=batch_texts[i] if i < len(batch_texts) else None,
            )
    elif response_type == "embeddings_by_type":
        embeddings_obj = response_data.get("embeddings", {})
        for emb_type, embeddings_list in embeddings_obj.items():
            # Strip trailing underscores from type keys (e.g. "float_").
            type_name = emb_type.rstrip("_")
            if isinstance(embeddings_list, list):
                for i, embedding in enumerate(embeddings_list):
                    yield StreamedEmbedding(
                        index=global_offset + i,
                        embedding=embedding,
                        embedding_type=type_name,
                        text=batch_texts[i] if i < len(batch_texts) else None,
                    )
    else:
        # V2 format: embeddings is a dict with type keys directly
        embeddings_obj = response_data.get("embeddings", {})
        if isinstance(embeddings_obj, dict):
            for emb_type, embeddings_list in embeddings_obj.items():
                type_name = emb_type.rstrip("_")
                if isinstance(embeddings_list, list):
                    for i, embedding in enumerate(embeddings_list):
                        yield StreamedEmbedding(
                            index=global_offset + i,
                            embedding=embedding,
                            embedding_type=type_name,
                            text=batch_texts[i] if i < len(batch_texts) else None,
                        )



================================================
FILE: src/cohere/manually_maintained/tokenizers.py
================================================
import asyncio
import logging
import typing

import requests
from tokenizers import Tokenizer  # type: ignore

if typing.TYPE_CHECKING:
    from cohere.client import AsyncClient, Client

TOKENIZER_CACHE_KEY = "tokenizers"
logger = logging.getLogger(__name__)


def tokenizer_cache_key(model: str) -> str:
    """Return the client-cache key under which a model's tokenizer is stored."""
    return f"{TOKENIZER_CACHE_KEY}:{model}"


def get_hf_tokenizer(co: "Client", model: str) -> Tokenizer:
    """Returns a HF tokenizer from a given tokenizer config URL."""
    tokenizer = co._cache_get(tokenizer_cache_key(model))
    if tokenizer is not None:
        return tokenizer

    tokenizer_url = co.models.get(model).tokenizer_url
    if not tokenizer_url:
        raise ValueError(f"No tokenizer URL found for model {model}")

    # Print the size of the tokenizer config before downloading it.
    try:
        size = _get_tokenizer_config_size(tokenizer_url)
        logger.info(f"Downloading tokenizer for model {model}. Size is {size} MBs.")
    except Exception as e:
        # Skip the size logging, this is not critical.
        # NOTE(review): logger.warn is a deprecated alias of logger.warning.
        logger.warn(f"Failed to get the size of the tokenizer config: {e}")

    response = requests.get(tokenizer_url)
    tokenizer = Tokenizer.from_str(response.text)
    co._cache_set(tokenizer_cache_key(model), tokenizer)
    return tokenizer


def local_tokenize(co: "Client", model: str, text: str) -> typing.List[int]:
    """Encodes a given text using a local tokenizer."""
    tokenizer = get_hf_tokenizer(co, model)
    return tokenizer.encode(text, add_special_tokens=False).ids


def local_detokenize(co: "Client", model: str, tokens: typing.Sequence[int]) -> str:
    """Decodes a given list of tokens using a local tokenizer."""
    tokenizer = get_hf_tokenizer(co, model)
    return tokenizer.decode(tokens)


async def async_get_hf_tokenizer(co: "AsyncClient", model: str) -> Tokenizer:
    """Returns a HF tokenizer from a given tokenizer config URL."""
    tokenizer = co._cache_get(tokenizer_cache_key(model))
    if tokenizer is not None:
        return tokenizer

    tokenizer_url = (await co.models.get(model)).tokenizer_url
    if not tokenizer_url:
        raise ValueError(f"No tokenizer URL found for model {model}")

    # Print the size of the tokenizer config before downloading it.
    try:
        size = _get_tokenizer_config_size(tokenizer_url)
        logger.info(f"Downloading tokenizer for model {model}. Size is {size} MBs.")
    except Exception as e:
        # Skip the size logging, this is not critical.
        logger.warn(f"Failed to get the size of the tokenizer config: {e}")

    # requests is blocking, so run the download in the default executor.
    response = await asyncio.get_event_loop().run_in_executor(None, requests.get, tokenizer_url)
    tokenizer = Tokenizer.from_str(response.text)
    co._cache_set(tokenizer_cache_key(model), tokenizer)
    return tokenizer


async def async_local_tokenize(co: "AsyncClient", model: str, text: str) -> typing.List[int]:
    """Encodes a given text using a local tokenizer."""
    tokenizer = await async_get_hf_tokenizer(co, model)
    return tokenizer.encode(text, add_special_tokens=False).ids


async def async_local_detokenize(co: "AsyncClient", model: str, tokens: typing.Sequence[int]) -> str:
    """Decodes a given list of tokens using a local tokenizer."""
    tokenizer = await async_get_hf_tokenizer(co, model)
    return tokenizer.decode(tokens)


def _get_tokenizer_config_size(tokenizer_url: str) -> float:
    # Get the size of the tokenizer config before downloading it.
    # Content-Length is not always present in the headers (if transfer-encoding: chunked).
    head_response = requests.head(tokenizer_url)
    size = None
    for header in ["x-goog-stored-content-length", "Content-Length"]:
        size = head_response.headers.get(header)
        if size:
            break
    # NOTE(review): if neither header is present, size stays None and
    # int(None) raises TypeError — callers wrap this in try/except.
    return round(int(typing.cast(int, size)) / 1024 / 1024, 2)



================================================
FILE: src/cohere/models/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.

# isort: skip_file



================================================
FILE: src/cohere/models/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from ..types.compatible_endpoint import CompatibleEndpoint
from ..types.get_model_response import GetModelResponse
from ..types.list_models_response import ListModelsResponse
from .raw_client import AsyncRawModelsClient, RawModelsClient


# Thin, typed wrapper over the raw Models client: every public method delegates
# to RawModelsClient and unwraps the parsed payload from the HTTP response.
class ModelsClient:
    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawModelsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawModelsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawModelsClient
        """
        return self._raw_client

    def get(self, model: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetModelResponse:
        """
        Returns the details of a model, provided its name.

        Parameters
        ----------
        model : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetModelResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.models.get(
            model="command-a-03-2025",
        )
        """
        # Delegate to the raw client, then return only the deserialized body.
        _response = self._raw_client.get(model, request_options=request_options)
        return _response.data

    def list(
        self,
        *,
        page_size: typing.Optional[float] = None,
        page_token: typing.Optional[str] = None,
        endpoint: typing.Optional[CompatibleEndpoint] = None,
        default_only: typing.Optional[bool] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListModelsResponse:
        """
        Returns a list of models available for use.

        Parameters
        ----------
        page_size : typing.Optional[float]
            Maximum number of models to include in a page
            Defaults to `20`, min value of `1`, max value of `1000`.

        page_token : typing.Optional[str]
            Page token provided in the `next_page_token` field of a previous response.
endpoint : typing.Optional[CompatibleEndpoint] When provided, filters the list of models to only those that are compatible with the specified endpoint. default_only : typing.Optional[bool] When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- ListModelsResponse OK Examples -------- from cohere import Client client = Client( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) client.models.list( page_size=1.1, page_token="page_token", endpoint="chat", default_only=True, ) """ _response = self._raw_client.list( page_size=page_size, page_token=page_token, endpoint=endpoint, default_only=default_only, request_options=request_options, ) return _response.data class AsyncModelsClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._raw_client = AsyncRawModelsClient(client_wrapper=client_wrapper) @property def with_raw_response(self) -> AsyncRawModelsClient: """ Retrieves a raw implementation of this client that returns raw responses. Returns ------- AsyncRawModelsClient """ return self._raw_client async def get(self, model: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetModelResponse: """ Returns the details of a model, provided its name. Parameters ---------- model : str request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- GetModelResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.models.get( model="command-a-03-2025", ) asyncio.run(main()) """ _response = await self._raw_client.get(model, request_options=request_options) return _response.data async def list( self, *, page_size: typing.Optional[float] = None, page_token: typing.Optional[str] = None, endpoint: typing.Optional[CompatibleEndpoint] = None, default_only: typing.Optional[bool] = None, request_options: typing.Optional[RequestOptions] = None, ) -> ListModelsResponse: """ Returns a list of models available for use. Parameters ---------- page_size : typing.Optional[float] Maximum number of models to include in a page Defaults to `20`, min value of `1`, max value of `1000`. page_token : typing.Optional[str] Page token provided in the `next_page_token` field of a previous response. endpoint : typing.Optional[CompatibleEndpoint] When provided, filters the list of models to only those that are compatible with the specified endpoint. default_only : typing.Optional[bool] When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- ListModelsResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.models.list( page_size=1.1, page_token="page_token", endpoint="chat", default_only=True, ) asyncio.run(main()) """ _response = await self._raw_client.list( page_size=page_size, page_token=page_token, endpoint=endpoint, default_only=default_only, request_options=request_options, ) return _response.data ================================================ FILE: src/cohere/models/raw_client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from json.decoder import JSONDecodeError from ..core.api_error import ApiError from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.http_response import AsyncHttpResponse, HttpResponse from ..core.jsonable_encoder import jsonable_encoder from ..core.parse_error import ParsingError from ..core.request_options import RequestOptions from ..core.unchecked_base_model import construct_type from ..errors.bad_request_error import BadRequestError from ..errors.client_closed_request_error import ClientClosedRequestError from ..errors.forbidden_error import ForbiddenError from ..errors.gateway_timeout_error import GatewayTimeoutError from ..errors.internal_server_error import InternalServerError from ..errors.invalid_token_error import InvalidTokenError from ..errors.not_found_error import NotFoundError from ..errors.not_implemented_error import NotImplementedError from ..errors.service_unavailable_error import ServiceUnavailableError from ..errors.too_many_requests_error import TooManyRequestsError from ..errors.unauthorized_error import UnauthorizedError from ..errors.unprocessable_entity_error import UnprocessableEntityError from ..types.compatible_endpoint import CompatibleEndpoint from ..types.get_model_response import 
GetModelResponse from ..types.list_models_response import ListModelsResponse from pydantic import ValidationError class RawModelsClient: def __init__(self, *, client_wrapper: SyncClientWrapper): self._client_wrapper = client_wrapper def get( self, model: str, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[GetModelResponse]: """ Returns the details of a model, provided its name. Parameters ---------- model : str request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[GetModelResponse] OK """ _response = self._client_wrapper.httpx_client.request( f"v1/models/{jsonable_encoder(model)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( GetModelResponse, construct_type( type_=GetModelResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: 
raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def list( self, *, page_size: typing.Optional[float] = None, page_token: typing.Optional[str] = None, endpoint: typing.Optional[CompatibleEndpoint] = None, 
default_only: typing.Optional[bool] = None, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[ListModelsResponse]: """ Returns a list of models available for use. Parameters ---------- page_size : typing.Optional[float] Maximum number of models to include in a page Defaults to `20`, min value of `1`, max value of `1000`. page_token : typing.Optional[str] Page token provided in the `next_page_token` field of a previous response. endpoint : typing.Optional[CompatibleEndpoint] When provided, filters the list of models to only those that are compatible with the specified endpoint. default_only : typing.Optional[bool] When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[ListModelsResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/models", method="GET", params={ "page_size": page_size, "page_token": page_token, "endpoint": endpoint, "default_only": default_only, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListModelsResponse, construct_type( type_=ListModelsResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if 
_response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), 
body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawModelsClient: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper async def get( self, model: str, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[GetModelResponse]: """ Returns the details of a model, provided its name. Parameters ---------- model : str request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[GetModelResponse] OK """ _response = await self._client_wrapper.httpx_client.request( f"v1/models/{jsonable_encoder(model)}", method="GET", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( GetModelResponse, construct_type( type_=GetModelResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise 
UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, 
headers=dict(_response.headers), body=_response_json) async def list( self, *, page_size: typing.Optional[float] = None, page_token: typing.Optional[str] = None, endpoint: typing.Optional[CompatibleEndpoint] = None, default_only: typing.Optional[bool] = None, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[ListModelsResponse]: """ Returns a list of models available for use. Parameters ---------- page_size : typing.Optional[float] Maximum number of models to include in a page Defaults to `20`, min value of `1`, max value of `1000`. page_token : typing.Optional[str] Page token provided in the `next_page_token` field of a previous response. endpoint : typing.Optional[CompatibleEndpoint] When provided, filters the list of models to only those that are compatible with the specified endpoint. default_only : typing.Optional[bool] When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[ListModelsResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/models", method="GET", params={ "page_size": page_size, "page_token": page_token, "endpoint": endpoint, "default_only": default_only, }, request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ListModelsResponse, construct_type( type_=ListModelsResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise 
ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) ================================================ FILE: src/cohere/oci_client.py ================================================ """Oracle Cloud Infrastructure (OCI) client for Cohere API.""" import configparser import email.utils import json import os import typing import uuid import httpx import requests from .client import Client, ClientEnvironment from .client_v2 import ClientV2 from .aws_client import Streamer from .manually_maintained.lazy_oci_deps import lazy_oci from httpx import URL, ByteStream class OciClient(Client): """ Cohere V1 API client for Oracle Cloud 
Infrastructure (OCI) Generative AI service. Use this client for V1 API models (Command R family) and embeddings. For V2 API models (Command A family), use OciClientV2 instead. Supported APIs on OCI: - embed(): Full support for all embedding models - chat(): Full support with Command-R models - chat_stream(): Streaming chat support Supports all authentication methods: - Config file (default): Uses ~/.oci/config - Session-based: Uses OCI CLI session tokens - Direct credentials: Pass OCI credentials directly - Instance principal: For OCI compute instances - Resource principal: For OCI functions Example: ```python import cohere client = cohere.OciClient( oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) response = client.chat( model="command-r-08-2024", message="Hello!", ) print(response.text) ``` """ def __init__( self, *, oci_config_path: typing.Optional[str] = None, oci_profile: typing.Optional[str] = None, oci_user_id: typing.Optional[str] = None, oci_fingerprint: typing.Optional[str] = None, oci_tenancy_id: typing.Optional[str] = None, oci_private_key_path: typing.Optional[str] = None, oci_private_key_content: typing.Optional[str] = None, auth_type: typing.Literal["api_key", "instance_principal", "resource_principal"] = "api_key", oci_region: typing.Optional[str] = None, oci_compartment_id: str, timeout: typing.Optional[float] = None, ): oci_config = _load_oci_config( auth_type=auth_type, config_path=oci_config_path, profile=oci_profile, user_id=oci_user_id, fingerprint=oci_fingerprint, tenancy_id=oci_tenancy_id, private_key_path=oci_private_key_path, private_key_content=oci_private_key_content, ) if oci_region is None: oci_region = oci_config.get("region") if oci_region is None: raise ValueError("oci_region must be provided either directly or in OCI config file") Client.__init__( self, base_url="https://api.cohere.com", environment=ClientEnvironment.PRODUCTION, client_name="n/a", timeout=timeout, api_key="n/a", httpx_client=httpx.Client( 
event_hooks=get_event_hooks( oci_config=oci_config, oci_region=oci_region, oci_compartment_id=oci_compartment_id, is_v2_client=False, ), timeout=timeout, ), ) class OciClientV2(ClientV2): """ Cohere V2 API client for Oracle Cloud Infrastructure (OCI) Generative AI service. Supported APIs on OCI: - embed(): Full support for all embedding models (returns embeddings as dict) - chat(): Full support with Command-A models (command-a-03-2025) - chat_stream(): Streaming chat with proper V2 event format Note: rerank() requires fine-tuned models deployed to dedicated endpoints. OCI on-demand inference does not support the rerank API. Supports all authentication methods: - Config file (default): Uses ~/.oci/config - Session-based: Uses OCI CLI session tokens - Direct credentials: Pass OCI credentials directly - Instance principal: For OCI compute instances - Resource principal: For OCI functions Example using config file: ```python import cohere client = cohere.OciClientV2( oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) response = client.embed( model="embed-english-v3.0", texts=["Hello world"], input_type="search_document", ) print(response.embeddings.float_) response = client.chat( model="command-a-03-2025", messages=[{"role": "user", "content": "Hello!"}], ) print(response.message) ``` Example using direct credentials: ```python client = cohere.OciClientV2( oci_user_id="ocid1.user.oc1...", oci_fingerprint="xx:xx:xx:...", oci_tenancy_id="ocid1.tenancy.oc1...", oci_private_key_path="~/.oci/key.pem", oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) ``` Example using instance principal: ```python client = cohere.OciClientV2( auth_type="instance_principal", oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1...", ) ``` """ def __init__( self, *, # Authentication - Config file (default) oci_config_path: typing.Optional[str] = None, oci_profile: typing.Optional[str] = None, # Authentication - Direct 
credentials oci_user_id: typing.Optional[str] = None, oci_fingerprint: typing.Optional[str] = None, oci_tenancy_id: typing.Optional[str] = None, oci_private_key_path: typing.Optional[str] = None, oci_private_key_content: typing.Optional[str] = None, # Authentication - Instance principal auth_type: typing.Literal["api_key", "instance_principal", "resource_principal"] = "api_key", # Required for OCI Generative AI oci_region: typing.Optional[str] = None, oci_compartment_id: str, # Standard parameters timeout: typing.Optional[float] = None, ): # Load OCI config based on auth_type oci_config = _load_oci_config( auth_type=auth_type, config_path=oci_config_path, profile=oci_profile, user_id=oci_user_id, fingerprint=oci_fingerprint, tenancy_id=oci_tenancy_id, private_key_path=oci_private_key_path, private_key_content=oci_private_key_content, ) # Get region from config if not provided if oci_region is None: oci_region = oci_config.get("region") if oci_region is None: raise ValueError("oci_region must be provided either directly or in OCI config file") # Create httpx client with OCI event hooks ClientV2.__init__( self, base_url="https://api.cohere.com", # Unused, OCI URL set in hooks environment=ClientEnvironment.PRODUCTION, client_name="n/a", timeout=timeout, api_key="n/a", httpx_client=httpx.Client( event_hooks=get_event_hooks( oci_config=oci_config, oci_region=oci_region, oci_compartment_id=oci_compartment_id, is_v2_client=True, ), timeout=timeout, ), ) EventHook = typing.Callable[..., typing.Any] def _load_oci_config( auth_type: str, config_path: typing.Optional[str], profile: typing.Optional[str], **kwargs: typing.Any, ) -> typing.Dict[str, typing.Any]: """ Load OCI configuration based on authentication type. Args: auth_type: Authentication method (api_key, instance_principal, resource_principal) config_path: Path to OCI config file (for api_key auth) profile: Profile name in config file (for api_key auth) **kwargs: Direct credentials (user_id, fingerprint, etc.) 
Returns: Dictionary containing OCI configuration """ oci = lazy_oci() if auth_type == "instance_principal": signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner() return {"signer": signer, "auth_type": "instance_principal"} elif auth_type == "resource_principal": signer = oci.auth.signers.get_resource_principals_signer() return {"signer": signer, "auth_type": "resource_principal"} elif kwargs.get("user_id"): # Direct credentials provided - validate required fields required_fields = ["fingerprint", "tenancy_id"] missing = [f for f in required_fields if not kwargs.get(f)] if missing: raise ValueError( f"When providing oci_user_id, you must also provide: {', '.join('oci_' + f for f in missing)}" ) if not kwargs.get("private_key_path") and not kwargs.get("private_key_content"): raise ValueError( "When providing oci_user_id, you must also provide either " "oci_private_key_path or oci_private_key_content" ) config = { "user": kwargs["user_id"], "fingerprint": kwargs["fingerprint"], "tenancy": kwargs["tenancy_id"], } if kwargs.get("private_key_path"): config["key_file"] = kwargs["private_key_path"] if kwargs.get("private_key_content"): config["key_content"] = kwargs["private_key_content"] return config else: # Load from config file oci_config = oci.config.from_file( file_location=config_path or "~/.oci/config", profile_name=profile or "DEFAULT" ) _remove_inherited_session_auth(oci_config, config_path=config_path, profile=profile) return oci_config def _remove_inherited_session_auth( oci_config: typing.Dict[str, typing.Any], *, config_path: typing.Optional[str], profile: typing.Optional[str], ) -> None: """Drop session auth fields inherited from the OCI config DEFAULT section.""" profile_name = profile or "DEFAULT" if profile_name == "DEFAULT" or "security_token_file" not in oci_config: return config_file = os.path.expanduser(config_path or "~/.oci/config") parser = configparser.ConfigParser(interpolation=None) if not parser.read(config_file): return if not 
parser.has_section(profile_name):
        # The requested profile is absent from the raw file; the token can only
        # have been inherited from DEFAULT, so drop it.
        oci_config.pop("security_token_file", None)
        return

    explicit_security_token = False
    current_section: typing.Optional[str] = None
    # Scan the raw config file to see whether the profile section itself sets
    # security_token_file (configparser merges DEFAULT keys, hiding the origin).
    with open(config_file, encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line or line.startswith(("#", ";")):
                continue
            if line.startswith("[") and line.endswith("]"):
                current_section = line[1:-1].strip()
                continue
            if current_section == profile_name and line.split("=", 1)[0].strip() == "security_token_file":
                explicit_security_token = True
                break

    if not explicit_security_token:
        oci_config.pop("security_token_file", None)


def _usage_from_oci(usage_data: typing.Optional[typing.Dict[str, typing.Any]]) -> typing.Dict[str, typing.Any]:
    """Map an OCI usage payload onto the Cohere `tokens`/`billed_units` usage shape.

    Accepts ``None`` (treated as empty). Reads ``inputTokens`` and, for output,
    prefers ``completionTokens`` over ``outputTokens``; missing keys default to 0.
    """
    usage_data = usage_data or {}
    input_tokens = usage_data.get("inputTokens", 0)
    output_tokens = usage_data.get("completionTokens", usage_data.get("outputTokens", 0))
    # Both sections mirror the same counts, matching the Cohere API usage schema.
    return {
        "tokens": {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        },
        "billed_units": {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }
    }


def get_event_hooks(
    oci_config: typing.Dict[str, typing.Any],
    oci_region: str,
    oci_compartment_id: str,
    is_v2_client: bool = False,
) -> typing.Dict[str, typing.List[EventHook]]:
    """
    Create httpx event hooks for OCI request/response transformation.
def get_event_hooks(
    oci_config: typing.Dict[str, typing.Any],
    oci_region: str,
    oci_compartment_id: str,
    is_v2_client: bool = False,
) -> typing.Dict[str, typing.List[EventHook]]:
    """
    Build the httpx event hooks that adapt Cohere SDK traffic to OCI.

    Args:
        oci_config: OCI configuration dictionary (API-key fields, a session
            token profile, or a pre-built "signer" entry).
        oci_region: OCI region (e.g., "us-chicago-1").
        oci_compartment_id: OCI compartment OCID.
        is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False).

    Returns:
        Dictionary with "request" and "response" hook lists for httpx.
    """
    return {
        "request": [
            map_request_to_oci(
                oci_config=oci_config,
                oci_region=oci_region,
                oci_compartment_id=oci_compartment_id,
                is_v2_client=is_v2_client,
            ),
        ],
        "response": [map_response_from_oci()],
    }


def map_request_to_oci(
    oci_config: typing.Dict[str, typing.Any],
    oci_region: str,
    oci_compartment_id: str,
    is_v2_client: bool = False,
) -> EventHook:
    """
    Create an httpx request hook that rewrites Cohere requests into OCI
    Generative AI requests and signs them.

    The signer is constructed once, up front, from ``oci_config``; the
    returned hook then rewrites every outgoing request in place (URL, body,
    headers, and httpx extensions used later by the response hook).

    Args:
        oci_config: OCI configuration dictionary.
        oci_region: OCI region.
        oci_compartment_id: OCI compartment OCID.
        is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False).

    Returns:
        Event hook function for httpx.

    Raises:
        ValueError: If the config supports no known authentication scheme.
    """
    oci = lazy_oci()

    # Create OCI signer based on config type.
    # Priority order: instance/resource principal > session-based auth > API key auth
    if "signer" in oci_config:
        signer = oci_config["signer"]  # Instance/resource principal
    elif "security_token_file" in oci_config:
        # Session-based authentication with security token.
        # The token file is re-read on every request so that OCI CLI token refreshes
        # (e.g. `oci session refresh`) are picked up without restarting the client.
        key_file = oci_config.get("key_file")
        if not key_file:
            raise ValueError(
                "OCI config profile is missing 'key_file'. "
                "Session-based auth requires a key_file entry in your OCI config profile."
            )
        token_file_path = os.path.expanduser(oci_config["security_token_file"])
        # The private key is loaded once; only the short-lived token refreshes.
        private_key = oci.signer.load_private_key_from_file(os.path.expanduser(key_file))

        class _RefreshingSecurityTokenSigner:
            """Wraps SecurityTokenSigner and re-reads the token file before each signing call."""

            def __init__(self) -> None:
                self._token_file = token_file_path
                self._private_key = private_key
                self._refresh()

            def _refresh(self) -> None:
                # Rebuild the inner signer with the current on-disk token.
                with open(self._token_file, "r") as _f:
                    _token = _f.read().strip()
                self._signer = oci.auth.signers.SecurityTokenSigner(
                    token=_token,
                    private_key=self._private_key,
                )

            # Delegate all attribute access to the inner signer, refreshing first.
            def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
                self._refresh()
                return self._signer(*args, **kwargs)

            def __getattr__(self, name: str) -> typing.Any:
                # Refuse private names so internal lookups don't recurse into
                # the delegate; everything else refreshes then forwards.
                if name.startswith("_"):
                    raise AttributeError(name)
                self._refresh()
                return getattr(self._signer, name)

        signer = _RefreshingSecurityTokenSigner()
    elif "user" in oci_config:
        # Standard API-key authentication (user/tenancy/fingerprint + key).
        signer = oci.signer.Signer(
            tenancy=oci_config["tenancy"],
            user=oci_config["user"],
            fingerprint=oci_config["fingerprint"],
            private_key_file_location=oci_config.get("key_file"),
            private_key_content=oci_config.get("key_content"),
        )
    else:
        # Config doesn't have user or security token - unsupported
        raise ValueError(
            "OCI config is missing 'user' field and no security_token_file found. "
            "Please use a profile with standard API key authentication, "
            "session-based authentication, or provide direct credentials via oci_user_id parameter."
        )

    def _event_hook(request: httpx.Request) -> None:
        # Extract Cohere API details: the last path segment ("chat"/"embed")
        # identifies the endpoint being proxied.
        path_parts = request.url.path.split("/")
        endpoint = path_parts[-1]
        body = json.loads(request.read())

        # Build OCI URL
        url = get_oci_url(
            region=oci_region,
            endpoint=endpoint,
        )

        # Transform request body to OCI format
        oci_body = transform_request_to_oci(
            endpoint=endpoint,
            cohere_body=body,
            compartment_id=oci_compartment_id,
            is_v2=is_v2_client,
        )

        # Prepare request for signing
        oci_body_bytes = json.dumps(oci_body).encode("utf-8")

        # Build headers for signing (RFC 1123 date is required by OCI signing)
        headers = {
            "content-type": "application/json",
            "date": email.utils.formatdate(usegmt=True),
        }

        # Create a requests.PreparedRequest for OCI signing — the OCI signer
        # operates on requests-style objects, not httpx ones.
        oci_request = requests.Request(
            method=request.method,
            url=url,
            headers=headers,
            data=oci_body_bytes,
        )
        prepped_request = oci_request.prepare()

        # Sign the request using OCI signer (modifies headers in place)
        signer.do_request_sign(prepped_request)

        # Update httpx request with signed headers and rewritten body.
        # NOTE(review): `request._content` is private httpx API — kept in sync
        # with the stream so later reads see the new body; verify on httpx upgrades.
        request.url = URL(url)
        request.headers = httpx.Headers(prepped_request.headers)
        request.stream = ByteStream(oci_body_bytes)
        request._content = oci_body_bytes
        # Stash routing info for the response hook on httpx extensions.
        request.extensions["endpoint"] = endpoint
        request.extensions["is_stream"] = body.get("stream", False)
        request.extensions["is_v2"] = is_v2_client

    return _event_hook
Returns: Event hook function for httpx """ def _hook(response: httpx.Response) -> None: endpoint = response.request.extensions["endpoint"] is_stream = response.request.extensions.get("is_stream", False) is_v2 = response.request.extensions.get("is_v2", False) output: typing.Iterator[bytes] # Only transform successful responses (200-299) # Let error responses pass through unchanged so SDK error handling works if not (200 <= response.status_code < 300): return # For streaming responses, wrap the stream with a transformer if is_stream: original_stream = typing.cast(typing.Iterator[bytes], response.stream) transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint, is_v2) response.stream = Streamer(transformed_stream) # Reset consumption flags if hasattr(response, "_content"): del response._content response.is_stream_consumed = False response.is_closed = False return # Handle non-streaming responses oci_response = json.loads(response.read()) cohere_response = transform_oci_response_to_cohere(endpoint, oci_response, is_v2) output = iter([json.dumps(cohere_response).encode("utf-8")]) response.stream = Streamer(output) # Reset response for re-reading if hasattr(response, "_content"): del response._content response.is_stream_consumed = False response.is_closed = False return _hook def get_oci_url( region: str, endpoint: str, ) -> str: """ Map Cohere endpoints to OCI Generative AI endpoints. Args: region: OCI region (e.g., "us-chicago-1") endpoint: Cohere endpoint name Returns: Full OCI Generative AI endpoint URL """ base = f"https://inference.generativeai.{region}.oci.oraclecloud.com" api_version = "20231130" # Map Cohere endpoints to OCI actions action_map = { "embed": "embedText", "chat": "chat", } action = action_map.get(endpoint) if action is None: raise ValueError( f"Endpoint '{endpoint}' is not supported by OCI Generative AI. 
" f"Supported endpoints: {list(action_map.keys())}" ) return f"{base}/{api_version}/actions/{action}" def normalize_model_for_oci(model: str) -> str: """ Normalize model name for OCI. OCI accepts model names in the format "cohere.model-name" or full OCIDs. This function ensures proper formatting for all regions. Args: model: Model name (e.g., "command-r-08-2024") or full OCID Returns: Normalized model identifier (e.g., "cohere.command-r-08-2024" or OCID) Examples: >>> normalize_model_for_oci("command-a-03-2025") "cohere.command-a-03-2025" >>> normalize_model_for_oci("cohere.embed-english-v3.0") "cohere.embed-english-v3.0" >>> normalize_model_for_oci("ocid1.generativeaimodel.oc1...") "ocid1.generativeaimodel.oc1..." """ if not model: raise ValueError("OCI requests require a non-empty model name") # If it's already an OCID, return as-is (works across all regions) if model.startswith("ocid1."): return model # Add "cohere." prefix if not present if not model.startswith("cohere."): return f"cohere.{model}" return model def transform_request_to_oci( endpoint: str, cohere_body: typing.Dict[str, typing.Any], compartment_id: str, is_v2: bool = False, ) -> typing.Dict[str, typing.Any]: """ Transform Cohere request body to OCI format. 
def transform_request_to_oci(
    endpoint: str,
    cohere_body: typing.Dict[str, typing.Any],
    compartment_id: str,
    is_v2: bool = False,
) -> typing.Dict[str, typing.Any]:
    """
    Transform a Cohere request body into OCI Generative AI format.

    Renames snake_case Cohere parameters to OCI camelCase, upper-cases OCI
    enum-like fields, and wraps everything in OCI's on-demand serving
    envelope (servingMode/compartmentId).

    Args:
        endpoint: Cohere endpoint name ("embed" or "chat").
        cohere_body: Original Cohere request body.
        compartment_id: OCI compartment OCID.
        is_v2: Whether this request comes from OciClientV2 (True) or OciClient (False).

    Returns:
        Transformed request body in OCI format.

    Raises:
        ValueError: On unsupported endpoints, missing embed inputs, or a
            request shape that does not match the client version.
    """
    model = normalize_model_for_oci(cohere_body.get("model", ""))

    if endpoint == "embed":
        # Either V1 "texts" or V2 "inputs" is accepted; top-level "images" is not.
        if "texts" in cohere_body:
            inputs = cohere_body["texts"]
        elif "inputs" in cohere_body:
            inputs = cohere_body["inputs"]
        elif "images" in cohere_body:
            raise ValueError("OCI embed does not support the top-level 'images' parameter; use 'inputs' instead")
        else:
            raise ValueError("OCI embed requires either 'texts' or 'inputs'")
        oci_body = {
            "inputs": inputs,
            "servingMode": {
                "servingType": "ON_DEMAND",
                "modelId": model,
            },
            "compartmentId": compartment_id,
        }
        # Add optional fields only if provided
        if "input_type" in cohere_body:
            oci_body["inputType"] = cohere_body["input_type"].upper()
        if "truncate" in cohere_body:
            oci_body["truncate"] = cohere_body["truncate"].upper()
        if "embedding_types" in cohere_body:
            # OCI expects lowercase embedding types (float, int8, binary, etc.)
            oci_body["embeddingTypes"] = [et.lower() for et in cohere_body["embedding_types"]]
        if "max_tokens" in cohere_body:
            oci_body["maxTokens"] = cohere_body["max_tokens"]
        if "output_dimension" in cohere_body:
            oci_body["outputDimension"] = cohere_body["output_dimension"]
        if "priority" in cohere_body:
            oci_body["priority"] = cohere_body["priority"]
        return oci_body

    elif endpoint == "chat":
        # Validate that the request body matches the client type: V2 clients
        # must send a 'messages' array, V1 clients a single 'message' string.
        has_messages = "messages" in cohere_body
        has_message = "message" in cohere_body
        if is_v2 and not has_messages:
            raise ValueError(
                "OciClientV2 requires the V2 API format ('messages' array). "
                "Got a V1-style request with 'message' string. "
                "Use OciClient for V1 models like Command R, "
                "or switch to the V2 messages format."
            )
        if not is_v2 and has_messages and not has_message:
            raise ValueError(
                "OciClient uses the V1 API format (single 'message' string). "
                "Got a V2-style request with 'messages' array. "
                "Use OciClientV2 for V2 models like Command A."
            )

        chat_request: typing.Dict[str, typing.Any] = {
            # apiFormat selects which Cohere dialect OCI parses the request as.
            "apiFormat": "COHEREV2" if is_v2 else "COHERE",
        }

        if is_v2:
            # V2: Transform Cohere V2 messages to OCI V2 format
            # Cohere sends: [{"role": "user", "content": "text"}]
            # OCI expects: [{"role": "USER", "content": [{"type": "TEXT", "text": "..."}]}]
            oci_messages = []
            for msg in cohere_body["messages"]:
                oci_msg: typing.Dict[str, typing.Any] = {
                    "role": msg["role"].upper(),
                }
                # Transform content: plain strings become a one-element TEXT
                # part list; structured parts get upper-cased types.
                if isinstance(msg.get("content"), str):
                    oci_msg["content"] = [{"type": "TEXT", "text": msg["content"]}]
                elif isinstance(msg.get("content"), list):
                    transformed_content = []
                    for item in msg["content"]:
                        if isinstance(item, dict) and "type" in item:
                            transformed_item = item.copy()
                            transformed_item["type"] = item["type"].upper()
                            # OCI expects camelCase: image_url → imageUrl
                            if "image_url" in transformed_item:
                                transformed_item["imageUrl"] = transformed_item.pop("image_url")
                            transformed_content.append(transformed_item)
                        else:
                            transformed_content.append(item)
                    oci_msg["content"] = transformed_content
                else:
                    # Missing/None content is normalized to an empty list.
                    oci_msg["content"] = msg.get("content") or []
                if "tool_calls" in msg:
                    oci_tool_calls = []
                    for tc in msg["tool_calls"]:
                        oci_tc = {**tc}
                        if "type" in oci_tc:
                            oci_tc["type"] = oci_tc["type"].upper()
                        oci_tool_calls.append(oci_tc)
                    oci_msg["toolCalls"] = oci_tool_calls
                if "tool_call_id" in msg:
                    oci_msg["toolCallId"] = msg["tool_call_id"]
                if "tool_plan" in msg:
                    oci_msg["toolPlan"] = msg["tool_plan"]
                oci_messages.append(oci_msg)
            chat_request["messages"] = oci_messages

            # V2 optional parameters (presence-checked, renamed to camelCase)
            if "max_tokens" in cohere_body:
                chat_request["maxTokens"] = cohere_body["max_tokens"]
            if "temperature" in cohere_body:
                chat_request["temperature"] = cohere_body["temperature"]
            if "k" in cohere_body:
                chat_request["topK"] = cohere_body["k"]
            if "p" in cohere_body:
                chat_request["topP"] = cohere_body["p"]
            if "seed" in cohere_body:
                chat_request["seed"] = cohere_body["seed"]
            if "frequency_penalty" in cohere_body:
                chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"]
            if "presence_penalty" in cohere_body:
                chat_request["presencePenalty"] = cohere_body["presence_penalty"]
            if "stop_sequences" in cohere_body:
                chat_request["stopSequences"] = cohere_body["stop_sequences"]
            if "tools" in cohere_body:
                oci_tools = []
                for tool in cohere_body["tools"]:
                    oci_tool = {**tool}
                    if "type" in oci_tool:
                        oci_tool["type"] = oci_tool["type"].upper()
                    oci_tools.append(oci_tool)
                chat_request["tools"] = oci_tools
            if "strict_tools" in cohere_body:
                chat_request["strictTools"] = cohere_body["strict_tools"]
            if "documents" in cohere_body:
                chat_request["documents"] = cohere_body["documents"]
            if "citation_options" in cohere_body:
                chat_request["citationOptions"] = cohere_body["citation_options"]
            if "response_format" in cohere_body:
                chat_request["responseFormat"] = cohere_body["response_format"]
            if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
                chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
            if "logprobs" in cohere_body:
                chat_request["logprobs"] = cohere_body["logprobs"]
            if "tool_choice" in cohere_body:
                chat_request["toolChoice"] = cohere_body["tool_choice"]
            if "priority" in cohere_body:
                chat_request["priority"] = cohere_body["priority"]
            # Thinking parameter for Command A Reasoning models
            if "thinking" in cohere_body and cohere_body["thinking"] is not None:
                thinking = cohere_body["thinking"]
                oci_thinking: typing.Dict[str, typing.Any] = {}
                if "type" in thinking:
                    oci_thinking["type"] = thinking["type"].upper()
                if "token_budget" in thinking and thinking["token_budget"] is not None:
                    oci_thinking["tokenBudget"] = thinking["token_budget"]
                if oci_thinking:
                    chat_request["thinking"] = oci_thinking
        else:
            # V1: single message string plus optional generation parameters.
            chat_request["message"] = cohere_body["message"]
            if "temperature" in cohere_body:
                chat_request["temperature"] = cohere_body["temperature"]
            if "max_tokens" in cohere_body:
                chat_request["maxTokens"] = cohere_body["max_tokens"]
            if "k" in cohere_body:
                chat_request["topK"] = cohere_body["k"]
            if "p" in cohere_body:
                chat_request["topP"] = cohere_body["p"]
            if "seed" in cohere_body:
                chat_request["seed"] = cohere_body["seed"]
            if "stop_sequences" in cohere_body:
                chat_request["stopSequences"] = cohere_body["stop_sequences"]
            if "frequency_penalty" in cohere_body:
                chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"]
            if "presence_penalty" in cohere_body:
                chat_request["presencePenalty"] = cohere_body["presence_penalty"]
            if "preamble" in cohere_body:
                chat_request["preambleOverride"] = cohere_body["preamble"]
            if "chat_history" in cohere_body:
                chat_request["chatHistory"] = cohere_body["chat_history"]
            if "documents" in cohere_body:
                chat_request["documents"] = cohere_body["documents"]
            if "tools" in cohere_body:
                oci_tools = []
                for tool in cohere_body["tools"]:
                    oci_tool = {**tool}
                    if "type" in oci_tool:
                        oci_tool["type"] = oci_tool["type"].upper()
                    oci_tools.append(oci_tool)
                chat_request["tools"] = oci_tools
            if "tool_results" in cohere_body:
                chat_request["toolResults"] = cohere_body["tool_results"]
            if "response_format" in cohere_body:
                chat_request["responseFormat"] = cohere_body["response_format"]
            if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
                chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
            if "priority" in cohere_body:
                chat_request["priority"] = cohere_body["priority"]

        # Handle streaming for both versions
        if cohere_body.get("stream"):
            chat_request["isStream"] = True

        # Top level OCI request structure
        oci_body = {
            "servingMode": {
                "servingType": "ON_DEMAND",
                "modelId": model,
            },
            "compartmentId": compartment_id,
            "chatRequest": chat_request,
        }
        return oci_body

    raise ValueError(
        f"Endpoint '{endpoint}' is not supported by OCI Generative AI on-demand inference. "
        "Supported endpoints: ['embed', 'chat']"
    )
" "Supported endpoints: ['embed', 'chat']" ) def transform_oci_response_to_cohere( endpoint: str, oci_response: typing.Dict[str, typing.Any], is_v2: bool = False, ) -> typing.Dict[str, typing.Any]: """ Transform OCI response to Cohere format. Args: endpoint: Cohere endpoint name oci_response: OCI response body is_v2: Whether this is a V2 API response Returns: Transformed response in Cohere format """ if endpoint == "embed": # OCI returns "embeddings" by default, or "embeddingsByType" when embeddingTypes is specified embeddings_data = oci_response.get("embeddingsByType") or oci_response.get("embeddings", {}) if isinstance(embeddings_data, dict): normalized_embeddings = {str(key).lower(): value for key, value in embeddings_data.items()} else: normalized_embeddings = {"float": embeddings_data} if is_v2: embeddings = normalized_embeddings else: embeddings = normalized_embeddings.get("float", []) meta = { "api_version": {"version": "1"}, } usage = _usage_from_oci(oci_response.get("usage")) if "tokens" in usage: meta["tokens"] = usage["tokens"] if "billed_units" in usage: meta["billed_units"] = usage["billed_units"] response_type = "embeddings_by_type" if is_v2 else "embeddings_floats" return { "response_type": response_type, "id": oci_response.get("id", str(uuid.uuid4())), "embeddings": embeddings, "texts": [], "meta": meta, } elif endpoint == "chat": chat_response = oci_response.get("chatResponse", {}) if is_v2: usage = _usage_from_oci(chat_response.get("usage")) message = chat_response.get("message", {}) if "role" in message: message = {**message, "role": message["role"].lower()} if "content" in message and isinstance(message["content"], list): transformed_content = [] for item in message["content"]: if isinstance(item, dict): transformed_item = item.copy() if "type" in transformed_item: transformed_item["type"] = transformed_item["type"].lower() transformed_content.append(transformed_item) else: transformed_content.append(item) message = {**message, "content": 
transformed_content} if "toolCalls" in message: tool_calls = [] for tc in message["toolCalls"]: lowered_tc = {**tc} if "type" in lowered_tc: lowered_tc["type"] = lowered_tc["type"].lower() tool_calls.append(lowered_tc) message = {k: v for k, v in message.items() if k != "toolCalls"} message["tool_calls"] = tool_calls if "toolPlan" in message: tool_plan = message["toolPlan"] message = {k: v for k, v in message.items() if k != "toolPlan"} message["tool_plan"] = tool_plan return { "id": chat_response.get("id", str(uuid.uuid4())), "message": message, "finish_reason": chat_response.get("finishReason", "COMPLETE"), "usage": usage, } # V1 response meta = { "api_version": {"version": "1"}, } usage = _usage_from_oci(chat_response.get("usage")) if "tokens" in usage: meta["tokens"] = usage["tokens"] if "billed_units" in usage: meta["billed_units"] = usage["billed_units"] return { "text": chat_response.get("text", ""), "generation_id": str(uuid.uuid4()), "chat_history": chat_response.get("chatHistory", []), "finish_reason": chat_response.get("finishReason", "COMPLETE"), "citations": chat_response.get("citations", []), "documents": chat_response.get("documents", []), "search_queries": chat_response.get("searchQueries", []), "meta": meta, } return oci_response def transform_oci_stream_wrapper( stream: typing.Iterator[bytes], endpoint: str, is_v2: bool = False, ) -> typing.Iterator[bytes]: """ Wrap OCI stream and transform events to Cohere format. 
def transform_oci_stream_wrapper(
    stream: typing.Iterator[bytes],
    endpoint: str,
    is_v2: bool = False,
) -> typing.Iterator[bytes]:
    """
    Wrap an OCI SSE byte stream and re-emit it as Cohere streaming events.

    V2 output is SSE-framed ("data: ...\\n\\n") with synthetic
    message-start/content-start/content-end/message-end envelope events;
    V1 output is newline-delimited JSON with stream-start/stream-end framing.

    Args:
        stream: Original OCI stream iterator (raw bytes chunks).
        endpoint: Cohere endpoint name.
        is_v2: Whether this is a V2 API stream.

    Yields:
        Bytes of transformed streaming events.
    """
    # One generation id is fabricated per stream and reused in all envelope events.
    generation_id = str(uuid.uuid4())
    # V2 envelope state: whether message-start/content-start were emitted,
    # whether the current content block was closed, and its running index.
    emitted_start = False
    emitted_content_end = False
    current_content_type: typing.Optional[str] = None
    current_content_index = 0
    final_finish_reason = "COMPLETE"
    final_usage: typing.Optional[typing.Dict[str, typing.Any]] = None
    # V1 accumulators for the final stream-end summary event.
    full_v1_text = ""
    final_v1_finish_reason = "COMPLETE"
    # Partial-line buffer: OCI chunks are not guaranteed to align on newlines.
    buffer = b""

    def _emit_v2_event(event: typing.Dict[str, typing.Any]) -> bytes:
        # V2 events are SSE frames.
        return b"data: " + json.dumps(event).encode("utf-8") + b"\n\n"

    def _emit_v1_event(event: typing.Dict[str, typing.Any]) -> bytes:
        # V1 events are newline-delimited JSON.
        return json.dumps(event).encode("utf-8") + b"\n"

    def _current_content_type(oci_event: typing.Dict[str, typing.Any]) -> typing.Optional[str]:
        # Classify the event's first content part as "thinking" or "text".
        message = oci_event.get("message")
        if isinstance(message, dict):
            content_list = message.get("content")
            if content_list and isinstance(content_list, list) and len(content_list) > 0:
                oci_type = content_list[0].get("type", "TEXT").upper()
                return "thinking" if oci_type == "THINKING" else "text"
        return None  # finish-only or non-content event — don't trigger a type transition

    def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]:
        nonlocal emitted_start, emitted_content_end, current_content_type, current_content_index
        nonlocal final_finish_reason, final_usage
        event_content_type = _current_content_type(oci_event)
        open_type = event_content_type or "text"
        if not emitted_start:
            # First event of the stream: open the message and first content block.
            yield _emit_v2_event(
                {
                    "type": "message-start",
                    "id": generation_id,
                    "delta": {"message": {"role": "assistant"}},
                }
            )
            yield _emit_v2_event(
                {
                    "type": "content-start",
                    "index": current_content_index,
                    "delta": {"message": {"content": {"type": open_type}}},
                }
            )
            emitted_start = True
            current_content_type = open_type
        elif event_content_type is not None and current_content_type != event_content_type:
            # Content type changed (e.g. thinking -> text): close the current
            # block and open a new one at the next index.
            yield _emit_v2_event({"type": "content-end", "index": current_content_index})
            current_content_index += 1
            yield _emit_v2_event(
                {
                    "type": "content-start",
                    "index": current_content_index,
                    "delta": {"message": {"content": {"type": event_content_type}}},
                }
            )
            current_content_type = event_content_type
            emitted_content_end = False
        for cohere_event in typing.cast(
            typing.List[typing.Dict[str, typing.Any]], transform_stream_event(endpoint, oci_event, is_v2=True)
        ):
            # Rebase per-event indices onto the running content index.
            if "index" in cohere_event:
                cohere_event = {**cohere_event, "index": current_content_index}
            if cohere_event["type"] == "content-end":
                emitted_content_end = True
            # Track latest finish reason/usage for the closing message-end event.
            final_finish_reason = oci_event.get("finishReason", final_finish_reason)
            final_usage = _usage_from_oci(oci_event.get("usage"))
            yield _emit_v2_event(cohere_event)

    def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]:
        nonlocal emitted_start, full_v1_text, final_v1_finish_reason
        if not emitted_start:
            yield _emit_v1_event({
                "event_type": "stream-start",
                "generation_id": generation_id,
                "is_finished": False,
            })
            emitted_start = True
        event = transform_stream_event(endpoint, oci_event, is_v2=False)
        if isinstance(event, dict):
            # Accumulate text so stream-end can carry the full response.
            if event.get("event_type") == "text-generation" and event.get("text"):
                full_v1_text += typing.cast(str, event["text"])
            if "finishReason" in oci_event:
                final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason)
            yield _emit_v1_event(event)

    stream_finished = False

    def _emit_closing_events() -> typing.Iterator[bytes]:
        """Emit the final closing events for the stream."""
        if is_v2:
            if emitted_start:
                if not emitted_content_end:
                    yield _emit_v2_event({"type": "content-end", "index": current_content_index})
                message_end_event: typing.Dict[str, typing.Any] = {
                    "type": "message-end",
                    "id": generation_id,
                    "delta": {"finish_reason": final_finish_reason},
                }
                if final_usage:
                    message_end_event["delta"]["usage"] = final_usage
                yield _emit_v2_event(message_end_event)
        else:
            yield _emit_v1_event(
                {
                    "event_type": "stream-end",
                    "finish_reason": final_v1_finish_reason,
                    "response": {
                        "text": full_v1_text,
                        "generation_id": generation_id,
                        "finish_reason": final_v1_finish_reason,
                    },
                }
            )

    def _process_line(line: str) -> typing.Iterator[bytes]:
        nonlocal stream_finished
        # Only SSE data lines carry payloads; ignore comments/blank keep-alives.
        if not line.startswith("data: "):
            return
        data_str = line[6:]
        if data_str.strip() == "[DONE]":
            for event_bytes in _emit_closing_events():
                yield event_bytes
            stream_finished = True
            return
        try:
            oci_event = json.loads(data_str)
        except json.JSONDecodeError:
            # Skip malformed frames rather than killing the stream.
            return
        try:
            if is_v2:
                for event_bytes in _transform_v2_event(oci_event):
                    yield event_bytes
            else:
                for event_bytes in _transform_v1_event(oci_event):
                    yield event_bytes
        except Exception as exc:
            raise RuntimeError(f"OCI stream event transformation failed for endpoint '{endpoint}': {exc}") from exc
        # OCI may not send [DONE] — treat finishReason as stream termination
        if "finishReason" in oci_event:
            for event_bytes in _emit_closing_events():
                yield event_bytes
            stream_finished = True

    for chunk in stream:
        buffer += chunk
        # Split complete lines out of the buffer; keep any partial tail.
        while b"\n" in buffer:
            line_bytes, buffer = buffer.split(b"\n", 1)
            line = line_bytes.decode("utf-8").strip()
            for event_bytes in _process_line(line):
                yield event_bytes
            if stream_finished:
                return
    # Flush a trailing line that arrived without a final newline.
    if buffer.strip() and not stream_finished:
        line = buffer.decode("utf-8").strip()
        for event_bytes in _process_line(line):
            yield event_bytes
""" if endpoint == "chat": if is_v2: content_type = "text" content_value = "" message = oci_event.get("message") if "message" in oci_event and not isinstance(message, dict): raise TypeError("OCI V2 stream event message must be an object") if isinstance(message, dict) and "content" in message: content_list = message["content"] if content_list and isinstance(content_list, list) and len(content_list) > 0: first_content = content_list[0] oci_type = first_content.get("type", "TEXT").upper() if oci_type == "THINKING": content_type = "thinking" content_value = first_content.get("thinking", "") else: content_type = "text" content_value = first_content.get("text", "") events: typing.List[typing.Dict[str, typing.Any]] = [] if content_value: delta_content: typing.Dict[str, typing.Any] = {} if content_type == "thinking": delta_content["thinking"] = content_value else: delta_content["text"] = content_value events.append( { "type": "content-delta", "index": 0, "delta": { "message": { "content": delta_content, } }, } ) if "finishReason" in oci_event: events.append( { "type": "content-end", "index": 0, } ) return events # V1 stream event return { "event_type": "text-generation", "text": oci_event.get("text", ""), "is_finished": oci_event.get("isFinished", False), } return [] if is_v2 else {} ================================================ FILE: src/cohere/overrides.py ================================================ import typing import uuid from . 
def _field_alias(field_info):
    """Return the alias configured on a pydantic field, or a falsy value if there is none."""
    # Pydantic v1 exposes the alias directly; v2 may tuck it into metadata[0].
    return field_info.alias or (
        field_info and field_info.metadata and field_info.metadata[0] and field_info.metadata[0].alias)  # type: ignore


def get_fields(obj) -> typing.List[str]:
    """Names of the model's declared fields, as plain strings."""
    return [str(field_name) for field_name in _get_model_fields(obj)]


def get_aliases_or_field(obj) -> typing.List[str]:
    """For each field, its configured alias when present, otherwise the field name itself."""
    resolved: typing.List[str] = []
    for field_name, field_info in _get_model_fields(obj).items():
        resolved.append(_field_alias(field_info) or field_name)  # type: ignore
    return resolved


def get_aliases_and_fields(obj):
    """Deduplicated union of field names and their aliases."""
    combined = set(get_fields(obj))
    combined.update(get_aliases_or_field(obj))
    return list(combined)


def allow_access_to_aliases(self: typing.Type["Model"], name):
    """__getattr__ replacement: resolve attribute access by field alias as well as field name."""
    for field_name, field_info in _get_model_fields(self).items():
        alias = _field_alias(field_info)
        if alias == name or field_name == name:
            return getattr(self, field_name)
    raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'")


def make_tool_call_v2_id_optional(cls):
    """
    Patch a class (ToolCallV2) so its 'id' field becomes optional, defaulting to a fresh UUID.

    Keeps backward compatibility with callers that never supplied an id: the
    wrapper injects the default before Pydantic validation ever sees the data.
    """
    wrapped_init = cls.__init__

    def patched_init(self, /, **data):
        """Inject a generated 'id' when absent, then delegate to the real __init__."""
        if 'id' not in data:
            data['id'] = str(uuid.uuid4())
        wrapped_init(self, **data)

    cls.__init__ = patched_init
    return cls
def run_overrides():
    """
    Apply runtime patches to generated classes.

    Lets us adjust generated code without editing the generated files
    themselves. Should be used judiciously!
    """
    # Let EmbedByTypeResponseEmbeddings resolve attribute access through field
    # aliases too, e.g. `embeddings.float` instead of `embeddings.float_`.
    setattr(EmbedByTypeResponseEmbeddings, "__getattr__", allow_access_to_aliases)

    # Imported here (not at module scope) to sidestep a circular dependency.
    from . import ToolCallV2

    # Allow ToolCallV2 to be constructed without an explicit id (defaults to a UUID).
    make_tool_call_v2_id_optional(ToolCallV2)


# Patch immediately at import time so every consumer sees the modified classes
# before any code tries to use them.
run_overrides()
UnauthorizedError from .errors.unprocessable_entity_error import UnprocessableEntityError from .types.chat_connector import ChatConnector from .types.chat_document import ChatDocument from .types.chat_request_citation_quality import ChatRequestCitationQuality from .types.chat_request_prompt_truncation import ChatRequestPromptTruncation from .types.chat_request_safety_mode import ChatRequestSafetyMode from .types.chat_stream_request_citation_quality import ChatStreamRequestCitationQuality from .types.chat_stream_request_prompt_truncation import ChatStreamRequestPromptTruncation from .types.chat_stream_request_safety_mode import ChatStreamRequestSafetyMode from .types.check_api_key_response import CheckApiKeyResponse from .types.classify_example import ClassifyExample from .types.classify_request_truncate import ClassifyRequestTruncate from .types.classify_response import ClassifyResponse from .types.detokenize_response import DetokenizeResponse from .types.embed_input_type import EmbedInputType from .types.embed_request_truncate import EmbedRequestTruncate from .types.embed_response import EmbedResponse from .types.embedding_type import EmbeddingType from .types.generate_request_return_likelihoods import GenerateRequestReturnLikelihoods from .types.generate_request_truncate import GenerateRequestTruncate from .types.generate_stream_request_return_likelihoods import GenerateStreamRequestReturnLikelihoods from .types.generate_stream_request_truncate import GenerateStreamRequestTruncate from .types.generate_streamed_response import GenerateStreamedResponse from .types.generation import Generation from .types.message import Message from .types.non_streamed_chat_response import NonStreamedChatResponse from .types.rerank_request_documents_item import RerankRequestDocumentsItem from .types.rerank_response import RerankResponse from .types.response_format import ResponseFormat from .types.streamed_chat_response import StreamedChatResponse from 
    @contextlib.contextmanager
    def chat_stream(
        self,
        *,
        message: str,
        accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
        model: typing.Optional[str] = OMIT,
        preamble: typing.Optional[str] = OMIT,
        chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
        conversation_id: typing.Optional[str] = OMIT,
        prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT,
        connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
        search_queries_only: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
        citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        max_input_tokens: typing.Optional[int] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
        tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
        force_single_step: typing.Optional[bool] = OMIT,
        response_format: typing.Optional[ResponseFormat] = OMIT,
        safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.Iterator[HttpResponse[typing.Iterator[StreamedChatResponse]]]:
        """
        Generates a streamed text response to a user message.

        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

        Parameters
        ----------
        message : str
            Text input for the model to respond to.
        accepts : typing.Optional[typing.Literal["text/event-stream"]]
            Pass text/event-stream to receive the streamed response as server-sent events.
            The default is `\\n` delimited events.
        model : typing.Optional[str]
            The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a fine-tuned model.
        preamble : typing.Optional[str]
            Replaces the default Cohere preamble; sent with the `SYSTEM` role at the start
            of the conversation (in contrast to `SYSTEM` entries in `chat_history`, which
            add content throughout the conversation).
        chat_history : typing.Optional[typing.Sequence[Message]]
            Previous messages between the user and the model, excluding the current user
            turn. Each item has a `role` (`CHATBOT`, `SYSTEM`, or `USER`) and a `message`.
            Prefer `preamble` over `SYSTEM` entries here for conversation-start content.
        conversation_id : typing.Optional[str]
            Alternative to `chat_history`: creates or resumes a persisted conversation
            with the specified ID (any non empty string). Cohere Platform only.
        prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation]
            How the prompt is constructed when inputs exceed the model's context limit:
            `AUTO` (drop and re-rank by relevance; Cohere Platform only),
            `AUTO_PRESERVE_ORDER` (drop but keep input order), or `OFF` (never drop; a
            `TooManyTokens` error is returned on overflow). Defaults to `AUTO` when
            `connectors` are specified, `OFF` otherwise.
        connectors : typing.Optional[typing.Sequence[ChatConnector]]
            Connectors (e.g. `{"id": "web-search"}` or a custom connector id) whose query
            results enrich the model's reply (RAG). Cohere Platform only.
        search_queries_only : typing.Optional[bool]
            When `true`, only generated search queries are returned; no search is run and
            no reply is generated. Defaults to `false`.
        documents : typing.Optional[typing.Sequence[ChatDocument]]
            Relevant documents the model can cite. Each is a string-string dictionary;
            keys/values are serialized and passed to the model. Optional `id` (citation
            identifier) and `_excludes` (keys hidden from the model) fields are supported.
        citation_quality : typing.Optional[ChatStreamRequestCitationQuality]
            Defaults to `"enabled"`; citations can be turned off with `"type": "disabled"`.
        temperature : typing.Optional[float]
            Non-negative float tuning generation randomness. Defaults to `0.3`.
        max_tokens : typing.Optional[int]
            Maximum number of tokens generated as part of the response.
        max_input_tokens : typing.Optional[int]
            Maximum number of input tokens sent to the model; input is truncated according
            to `prompt_truncation`. Cohere Platform only.
        k : typing.Optional[int]
            Top-k sampling: only the `k` most likely tokens are considered each step.
            Defaults to `0`; range `0`-`500`.
        p : typing.Optional[float]
            Nucleus (top-p) sampling with total probability mass `p`; applied after `k`
            when both are set. Defaults to `0.75`; range `0.01`-`0.99`.
        seed : typing.Optional[int]
            Best-effort deterministic sampling for repeated identical requests;
            determinism cannot be totally guaranteed.
        stop_sequences : typing.Optional[typing.Sequence[str]]
            Up to 5 strings that stop generation; the matched stop sequence is excluded
            from the returned text.
        frequency_penalty : typing.Optional[float]
            Penalty proportional to how often a token has already appeared. Defaults to
            `0.0`; range `0.0`-`1.0`.
        presence_penalty : typing.Optional[float]
            Penalty applied equally to all tokens that have already appeared, regardless
            of frequency. Defaults to `0.0`; range `0.0`-`1.0`.
        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt is sent to the model without any pre-processing.
        tools : typing.Optional[typing.Sequence[Tool]]
            Available tools (functions) the model may suggest invoking. When passed without
            `tool_results`, the response `text` is `""` and `tool_calls` lists the calls to
            make (empty if none are needed).
        tool_results : typing.Optional[typing.Sequence[ToolResult]]
            Results from invoking tools recommended in the previous turn; used to produce
            a cited text response. Requires `tools`. Each result's `outputs` must be a
            list of objects (wrap a single object in a list). Calls with `tool_results`
            should not be duplicated into the chat history.
        force_single_step : typing.Optional[bool]
            Forces the chat to be single step. Defaults to `false`.
        response_format : typing.Optional[ResponseFormat]
        safety_mode : typing.Optional[ChatStreamRequestSafetyMode]
            Selects the [safety instruction](https://docs.cohere.com/docs/safety-modes)
            inserted into the prompt (`CONTEXTUAL` by default; `NONE` omits it). Not yet
            configurable in combination with `tools`, `tool_results` and `documents`;
            only supported on Command R 08-2024 and newer models.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.Iterator[HttpResponse[typing.Iterator[StreamedChatResponse]]]
        """
        # Open the HTTP stream for the whole lifetime of this context manager;
        # OMIT-valued fields are dropped from the JSON body by the client wrapper.
        with self._client_wrapper.httpx_client.stream(
            "v1/chat",
            method="POST",
            json={
                "message": message,
                "model": model,
                "preamble": preamble,
                "chat_history": convert_and_respect_annotation_metadata(
                    object_=chat_history, annotation=typing.Sequence[Message], direction="write"
                ),
                "conversation_id": conversation_id,
                "prompt_truncation": prompt_truncation,
                "connectors": convert_and_respect_annotation_metadata(
                    object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write"
                ),
                "search_queries_only": search_queries_only,
                "documents": documents,
                "citation_quality": citation_quality,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "max_input_tokens": max_input_tokens,
                "k": k,
                "p": p,
                "seed": seed,
                "stop_sequences": stop_sequences,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "raw_prompting": raw_prompting,
                "tools": convert_and_respect_annotation_metadata(
                    object_=tools, annotation=typing.Sequence[Tool], direction="write"
                ),
                "tool_results": convert_and_respect_annotation_metadata(
                    object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write"
                ),
                "force_single_step": force_single_step,
                "response_format": convert_and_respect_annotation_metadata(
                    object_=response_format, annotation=ResponseFormat, direction="write"
                ),
                # Always request the streaming variant of the endpoint.
                "stream": True,
            },
            headers={
                "content-type": "application/json",
                # Selects SSE framing when the caller asked for it; None headers are dropped.
                "Accepts": str(accepts) if accepts is not None else None,
            },
            request_options=request_options,
            omit=OMIT,
        ) as _response:

            def _stream() -> HttpResponse[typing.Iterator[StreamedChatResponse]]:
                try:
                    if 200 <= _response.status_code < 300:

                        def _iter():
                            # Parse each non-empty line as one streamed event;
                            # malformed/unparseable lines are deliberately skipped.
                            for _text in _response.iter_lines():
                                try:
                                    if len(_text) == 0:
                                        continue
                                    yield typing.cast(
                                        StreamedChatResponse,
                                        construct_type(
                                            type_=StreamedChatResponse,  # type: ignore
                                            object_=json.loads(_text),
                                        ),
                                    )
                                except Exception:
                                    pass
                            return

                        return HttpResponse(response=_response, data=_iter())
                    # Error path: the body must be fully read before .json()/.text
                    # can be used on a streaming response.
                    _response.read()
                    if _response.status_code == 400:
                        raise BadRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 401:
                        raise UnauthorizedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 403:
                        raise ForbiddenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 404:
                        raise NotFoundError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 422:
                        raise UnprocessableEntityError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 429:
                        raise TooManyRequestsError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 498:
                        raise InvalidTokenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 499:
                        raise ClientClosedRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 500:
                        raise InternalServerError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 501:
                        # Project-defined NotImplementedError (imported above), not the builtin.
                        raise NotImplementedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 503:
                        raise ServiceUnavailableError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 504:
                        raise GatewayTimeoutError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    # Unmapped status code: fall through to a generic ApiError below.
                    _response_json = _response.json()
                except JSONDecodeError:
                    # Error body was not valid JSON; surface the raw text instead.
                    raise ApiError(
                        status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
                    )
                except ValidationError as e:
                    raise ParsingError(
                        status_code=_response.status_code,
                        headers=dict(_response.headers),
                        body=_response.json(),
                        cause=e,
                    )
                raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

            yield _stream()
typing.Optional[typing.Sequence[ChatDocument]] = OMIT, citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT, temperature: typing.Optional[float] = OMIT, max_tokens: typing.Optional[int] = OMIT, max_input_tokens: typing.Optional[int] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, presence_penalty: typing.Optional[float] = OMIT, raw_prompting: typing.Optional[bool] = OMIT, tools: typing.Optional[typing.Sequence[Tool]] = OMIT, tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT, force_single_step: typing.Optional[bool] = OMIT, response_format: typing.Optional[ResponseFormat] = OMIT, safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[NonStreamedChatResponse]: """ Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api). Parameters ---------- message : str Text input for the model to respond to. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments accepts : typing.Optional[typing.Literal["text/event-stream"]] Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events. model : typing.Optional[str] The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Compatible Deployments: Cohere Platform, Private Deployments preamble : typing.Optional[str] When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. 
The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments chat_history : typing.Optional[typing.Sequence[Message]] A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments conversation_id : typing.Optional[str] An alternative to `chat_history`. Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string. Compatible Deployments: Cohere Platform prompt_truncation : typing.Optional[ChatRequestPromptTruncation] Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. 
With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. Compatible Deployments: - AUTO: Cohere Platform Only - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments connectors : typing.Optional[typing.Sequence[ChatConnector]] Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one. When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG). Compatible Deployments: Cohere Platform search_queries_only : typing.Optional[bool] Defaults to `false`. When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments documents : typing.Optional[typing.Sequence[ChatDocument]] A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. Example: ``` [ { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." }, ] ``` Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. Some suggested keys are "text", "author", and "date". 
For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments citation_quality : typing.Optional[ChatRequestCitationQuality] Defaults to `"enabled"`. Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments temperature : typing.Optional[float] Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_tokens : typing.Optional[int] The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_input_tokens : typing.Optional[int] The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer. Input will be truncated according to the `prompt_truncation` parameter. 
Compatible Deployments: Cohere Platform k : typing.Optional[int] Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments stop_sequences : typing.Optional[typing.Sequence[str]] A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments frequency_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. 
Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments raw_prompting : typing.Optional[bool] When enabled, the user's prompt will be sent to the model without any pre-processing. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tools : typing.Optional[typing.Sequence[Tool]] A list of available tools (functions) that the model may suggest invoking before producing a text response. When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tool_results : typing.Optional[typing.Sequence[ToolResult]] A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. ``` tool_results = [ { "call": { "name": , "parameters": { : } }, "outputs": [{ : }] }, ... ] ``` **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments force_single_step : typing.Optional[bool] Forces the chat to be single step. Defaults to `false`. 
response_format : typing.Optional[ResponseFormat] safety_mode : typing.Optional[ChatRequestSafetyMode] Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `NONE` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[NonStreamedChatResponse] """ _response = self._client_wrapper.httpx_client.request( "v1/chat", method="POST", json={ "message": message, "model": model, "preamble": preamble, "chat_history": convert_and_respect_annotation_metadata( object_=chat_history, annotation=typing.Sequence[Message], direction="write" ), "conversation_id": conversation_id, "prompt_truncation": prompt_truncation, "connectors": convert_and_respect_annotation_metadata( object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write" ), "search_queries_only": search_queries_only, "documents": documents, "citation_quality": citation_quality, "temperature": temperature, "max_tokens": max_tokens, "max_input_tokens": max_input_tokens, "k": k, "p": p, "seed": seed, "stop_sequences": stop_sequences, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "raw_prompting": raw_prompting, "tools": convert_and_respect_annotation_metadata( object_=tools, annotation=typing.Sequence[Tool], direction="write" ), "tool_results": 
convert_and_respect_annotation_metadata( object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write" ), "force_single_step": force_single_step, "response_format": convert_and_respect_annotation_metadata( object_=response_format, annotation=ResponseFormat, direction="write" ), "safety_mode": safety_mode, "stream": False, }, headers={ "content-type": "application/json", "Accepts": str(accepts) if accepts is not None else None, }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( NonStreamedChatResponse, construct_type( type_=NonStreamedChatResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise 
InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) @contextlib.contextmanager def generate_stream( self, *, prompt: str, model: typing.Optional[str] = OMIT, num_generations: typing.Optional[int] = OMIT, max_tokens: typing.Optional[int] = OMIT, truncate: typing.Optional[GenerateStreamRequestTruncate] = OMIT, temperature: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, preset: typing.Optional[str] = OMIT, 
        end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.Iterator[HttpResponse[typing.Iterator[GenerateStreamedResponse]]]:
        """
        This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.

        Generates realistic text conditioned on a given input.

        Parameters
        ----------
        prompt : str
            The input text that serves as the starting point for generating the response.
            Note: The prompt will be pre-processed and modified before reaching the model.

        model : typing.Optional[str]
            The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
            Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

        num_generations : typing.Optional[int]
            The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
            This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.

            Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

        truncate : typing.Optional[GenerateStreamRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        temperature : typing.Optional[float]
            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
            Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        preset : typing.Optional[str]
            Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
            When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

        end_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        frequency_penalty : typing.Optional[float]
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods]
            One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.

            If `GENERATION` is selected, the token likelihoods will only be provided for generated text.

            WARNING: `ALL` is deprecated, and will be removed in a future release.

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.Iterator[HttpResponse[typing.Iterator[GenerateStreamedResponse]]]

        """
        # Open the HTTP stream; parameters left as the OMIT sentinel are marked for
        # omission via omit=OMIT below (presumably filtered by the client wrapper — verify).
        with self._client_wrapper.httpx_client.stream(
            "v1/generate",
            method="POST",
            json={
                "prompt": prompt,
                "model": model,
                "num_generations": num_generations,
                "max_tokens": max_tokens,
                "truncate": truncate,
                "temperature": temperature,
                "seed": seed,
                "preset": preset,
                "end_sequences": end_sequences,
                "stop_sequences": stop_sequences,
                "k": k,
                "p": p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "return_likelihoods": return_likelihoods,
                "raw_prompting": raw_prompting,
                "stream": True,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        ) as _response:

            def _stream() -> HttpResponse[typing.Iterator[GenerateStreamedResponse]]:
                try:
                    if 200 <= _response.status_code < 300:
                        # Success path: lazily decode each streamed line into a
                        # GenerateStreamedResponse as the caller iterates.
                        def _iter():
                            for _text in _response.iter_lines():
                                try:
                                    # Skip blank keep-alive lines; lines that fail to
                                    # parse are silently dropped (best-effort streaming).
                                    if len(_text) == 0:
                                        continue
                                    yield typing.cast(
                                        GenerateStreamedResponse,
                                        construct_type(
                                            type_=GenerateStreamedResponse,  # type: ignore
                                            object_=json.loads(_text),
                                        ),
                                    )
                                except Exception:
                                    pass
                            return

                        return HttpResponse(response=_response, data=_iter())
                    # Error path: load the full streamed body so the branches below
                    # can call _response.json().
                    _response.read()
                    # Map documented error statuses to typed exceptions; anything
                    # unrecognized falls through to the generic ApiError at the bottom.
                    if _response.status_code == 400:
                        raise BadRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 401:
                        raise UnauthorizedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 403:
                        raise ForbiddenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 404:
                        raise NotFoundError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 422:
                        raise UnprocessableEntityError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 429:
                        raise TooManyRequestsError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 498:
                        raise InvalidTokenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 499:
                        raise ClientClosedRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 500:
                        raise InternalServerError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 501:
                        # NOTE(review): this is the SDK's NotImplementedError error type
                        # imported by this module, which shadows the builtin — confirm.
                        raise NotImplementedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 503:
                        raise ServiceUnavailableError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 504:
                        raise GatewayTimeoutError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    # Unrecognized status: attempt to decode the body for the ApiError.
                    _response_json = _response.json()
                except JSONDecodeError:
                    # Body was not valid JSON — surface the raw text instead.
                    raise ApiError(
                        status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
                    )
                except ValidationError as e:
                    # Body was JSON but did not match the expected model shape.
                    raise ParsingError(
                        status_code=_response.status_code,
                        headers=dict(_response.headers),
                        body=_response.json(),
                        cause=e,
                    )
                raise ApiError(status_code=_response.status_code,
headers=dict(_response.headers), body=_response_json) yield _stream() def generate( self, *, prompt: str, model: typing.Optional[str] = OMIT, num_generations: typing.Optional[int] = OMIT, max_tokens: typing.Optional[int] = OMIT, truncate: typing.Optional[GenerateRequestTruncate] = OMIT, temperature: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, preset: typing.Optional[str] = OMIT, end_sequences: typing.Optional[typing.Sequence[str]] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, presence_penalty: typing.Optional[float] = OMIT, return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT, raw_prompting: typing.Optional[bool] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[Generation]: """ This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. Generates realistic text conditioned on a given input. Parameters ---------- prompt : str The input text that serves as the starting point for generating the response. Note: The prompt will be pre-processed and modified before reaching the model. model : typing.Optional[str] The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID. num_generations : typing.Optional[int] The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`. 
        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
            This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.

            Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

        truncate : typing.Optional[GenerateRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        temperature : typing.Optional[float]
            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
            Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        preset : typing.Optional[str]
            Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
            When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

        end_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        frequency_penalty : typing.Optional[float]
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods]
            One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.

            If `GENERATION` is selected, the token likelihoods will only be provided for generated text.

            WARNING: `ALL` is deprecated, and will be removed in a future release.

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[Generation]

        """
        # Non-streaming variant: same payload shape as generate_stream but with
        # "stream": False. Parameters left as the OMIT sentinel are dropped via omit=OMIT.
        _response = self._client_wrapper.httpx_client.request(
            "v1/generate",
            method="POST",
            json={
                "prompt": prompt,
                "model": model,
                "num_generations": num_generations,
                "max_tokens": max_tokens,
                "truncate": truncate,
                "temperature": temperature,
                "seed": seed,
                "preset": preset,
                "end_sequences": end_sequences,
                "stop_sequences": stop_sequences,
                "k": k,
                "p": p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "return_likelihoods": return_likelihoods,
                "raw_prompting": raw_prompting,
                "stream": False,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if 200 <= _response.status_code < 300:
                # 2xx: parse the body into a Generation and return it with the raw response.
                _data = typing.cast(
                    Generation,
                    construct_type(
                        type_=Generation,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            # Map documented error statuses to typed exceptions; anything
            # unrecognized falls through to the generic ApiError below.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def embed( self, *, 
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[EmbedResponse]:
        """
        This endpoint returns text and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents.

        Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

        If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

        Parameters
        ----------
        texts : typing.Optional[typing.Sequence[str]]
            An array of strings for the model to embed. Maximum number of texts per call is `96`.

        images : typing.Optional[typing.Sequence[str]]
            An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

            The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.

            Images are only supported with Embed v3.0 and newer models.

        model : typing.Optional[str]
            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

        input_type : typing.Optional[EmbedInputType]

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.

        truncate : typing.Optional[EmbedRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[EmbedResponse]
            OK
        """
        # Parameters left as the OMIT sentinel are dropped from the payload via omit=OMIT.
        _response = self._client_wrapper.httpx_client.request(
            "v1/embed",
            method="POST",
            json={
                "texts": texts,
                "images": images,
                "model": model,
                "input_type": input_type,
                "embedding_types": embedding_types,
                "truncate": truncate,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if 200 <= _response.status_code < 300:
                # 2xx: parse the body into an EmbedResponse and return it with the raw response.
                _data = typing.cast(
                    EmbedResponse,
                    construct_type(
                        type_=EmbedResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            # Map documented error statuses to typed exceptions; anything
            # unrecognized falls through to the generic ApiError below.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def rerank( self, *, query: str, documents: typing.Sequence[RerankRequestDocumentsItem], model: typing.Optional[str] = OMIT, top_n: typing.Optional[int] = OMIT, rank_fields: typing.Optional[typing.Sequence[str]] = OMIT, return_documents: typing.Optional[bool] = OMIT, max_chunks_per_doc: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[RerankResponse]: """ This endpoint takes in a query and a list of texts and produces an ordered array with each text 
assigned a relevance score. Parameters ---------- query : str The search query documents : typing.Sequence[RerankRequestDocumentsItem] A list of document objects or strings to rerank. If a document is provided the text fields is required and all other fields will be preserved in the response. The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000. We recommend a maximum of 1,000 documents for optimal endpoint performance. model : typing.Optional[str] The identifier of the model to use, eg `rerank-v3.5`. top_n : typing.Optional[int] The number of most relevant documents or indices to return, defaults to the length of the documents rank_fields : typing.Optional[typing.Sequence[str]] If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially. If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking. return_documents : typing.Optional[bool] - If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request. - If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request. max_chunks_per_doc : typing.Optional[int] The maximum number of chunks to produce internally from a document request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[RerankResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/rerank", method="POST", json={ "model": model, "query": query, "documents": convert_and_respect_annotation_metadata( object_=documents, annotation=typing.Sequence[RerankRequestDocumentsItem], direction="write" ), "top_n": top_n, "rank_fields": rank_fields, "return_documents": return_documents, "max_chunks_per_doc": max_chunks_per_doc, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( RerankResponse, construct_type( type_=RerankResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if 
_response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def classify( self, *, inputs: typing.Sequence[str], examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT, model: typing.Optional[str] = OMIT, preset: typing.Optional[str] = OMIT, truncate: typing.Optional[ClassifyRequestTruncate] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> 
HttpResponse[ClassifyResponse]: """ This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference. Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. Parameters ---------- inputs : typing.Sequence[str] A list of up to 96 texts to be classified. Each one must be a non-empty string. There is, however, no consistent, universal limit to the length a particular input can be. We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models). Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts. examples : typing.Optional[typing.Sequence[ClassifyExample]] An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`. Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. model : typing.Optional[str] ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model preset : typing.Optional[str] The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground). 
If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters. truncate : typing.Optional[ClassifyRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[ClassifyResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/classify", method="POST", json={ "inputs": inputs, "examples": convert_and_respect_annotation_metadata( object_=examples, annotation=typing.Sequence[ClassifyExample], direction="write" ), "model": model, "preset": preset, "truncate": truncate, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ClassifyResponse, construct_type( type_=ClassifyResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: 
raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except 
ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def summarize( self, *, text: str, length: typing.Optional[SummarizeRequestLength] = OMIT, format: typing.Optional[SummarizeRequestFormat] = OMIT, model: typing.Optional[str] = OMIT, extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT, temperature: typing.Optional[float] = OMIT, additional_command: typing.Optional[str] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[SummarizeResponse]: """ This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. Generates a summary in English for a given text. Parameters ---------- text : str The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English. length : typing.Optional[SummarizeRequestLength] One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text. format : typing.Optional[SummarizeRequestFormat] One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text. model : typing.Optional[str] The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. 
extractiveness : typing.Optional[SummarizeRequestExtractiveness] One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text. temperature : typing.Optional[float] Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1. additional_command : typing.Optional[str] A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda" request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[SummarizeResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/summarize", method="POST", json={ "text": text, "length": length, "format": format, "model": model, "extractiveness": extractiveness, "temperature": temperature, "additional_command": additional_command, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( SummarizeResponse, construct_type( type_=SummarizeResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), 
body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def tokenize( self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[TokenizeResponse]: """ This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page. Parameters ---------- text : str The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters. model : str The input will be tokenized by the tokenizer that is used by this model. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[TokenizeResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/tokenize", method="POST", json={ "text": text, "model": model, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( TokenizeResponse, construct_type( type_=TokenizeResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def detokenize( self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[DetokenizeResponse]: """ This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page. Parameters ---------- tokens : typing.Sequence[int] The list of tokens to be detokenized. model : str An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model. 
request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[DetokenizeResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/detokenize", method="POST", json={ "tokens": tokens, "model": model, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DetokenizeResponse, construct_type( type_=DetokenizeResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def check_api_key( self, *, request_options: typing.Optional[RequestOptions] = None ) -> HttpResponse[CheckApiKeyResponse]: """ Checks that the api key in the Authorization header is valid and active Parameters ---------- request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- HttpResponse[CheckApiKeyResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v1/check-api-key", method="POST", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CheckApiKeyResponse, construct_type( type_=CheckApiKeyResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, 
# type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawBaseCohere: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper @contextlib.asynccontextmanager async def chat_stream( self, *, message: str, accepts: typing.Optional[typing.Literal["text/event-stream"]] = None, model: typing.Optional[str] = OMIT, preamble: typing.Optional[str] = OMIT, chat_history: typing.Optional[typing.Sequence[Message]] = OMIT, conversation_id: typing.Optional[str] = OMIT, prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT, connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT, search_queries_only: typing.Optional[bool] = OMIT, documents: 
        typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
        citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        max_input_tokens: typing.Optional[int] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
        tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
        force_single_step: typing.Optional[bool] = OMIT,
        response_format: typing.Optional[ResponseFormat] = OMIT,
        safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[StreamedChatResponse]]]:
        """
        Generates a streamed text response to a user message.

        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

        Parameters
        ----------
        message : str
            Text input for the model to respond to.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        accepts : typing.Optional[typing.Literal["text/event-stream"]]
            Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.

        model : typing.Optional[str]
            The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.

            Compatible Deployments: Cohere Platform, Private Deployments

        preamble : typing.Optional[str]
            When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.

            The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        chat_history : typing.Optional[typing.Sequence[Message]]
            A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.

            Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.

            The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        conversation_id : typing.Optional[str]
            An alternative to `chat_history`.

            Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non-empty string.

            Compatible Deployments: Cohere Platform

        prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation]
            Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.

            Dictates how the prompt will be constructed.

            With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.

            With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.

            With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.

            Compatible Deployments:
             - AUTO: Cohere Platform Only
             - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments

        connectors : typing.Optional[typing.Sequence[ChatConnector]]
            Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.

            When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).

            Compatible Deployments: Cohere Platform

        search_queries_only : typing.Optional[bool]
            Defaults to `false`.

            When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        documents : typing.Optional[typing.Sequence[ChatDocument]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.

            Example:
            ```
            [
              { "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
              { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
            ]
            ```

            Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.

            Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.

            An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.

            An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.

            See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        citation_quality : typing.Optional[ChatStreamRequestCitationQuality]
            Defaults to `"enabled"`.

            Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_input_tokens : typing.Optional[int]
            The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.

            Input will be truncated according to the `prompt_truncation` parameter.

            Compatible Deployments: Cohere Platform

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tools : typing.Optional[typing.Sequence[Tool]]
            A list of available tools (functions) that the model may suggest invoking before producing a text response.

            When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tool_results : typing.Optional[typing.Sequence[ToolResult]]
            A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
            Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.

            **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
            ```
            tool_results = [
              {
                "call": {
                  "name": ,
                  "parameters": {
                    :
                  }
                },
                "outputs": [{
                  :
                }]
              },
              ...
            ]
            ```
            **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        force_single_step : typing.Optional[bool]
            Forces the chat to be single step. Defaults to `false`.

        response_format : typing.Optional[ResponseFormat]

        safety_mode : typing.Optional[ChatStreamRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `NONE` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[StreamedChatResponse]]]
        """
        # Open the response as an async context manager so the underlying HTTP
        # stream is closed when the caller leaves the `async with` block.
        async with self._client_wrapper.httpx_client.stream(
            "v1/chat",
            method="POST",
            json={
                "message": message,
                "model": model,
                "preamble": preamble,
                # Sequence/model fields are routed through
                # convert_and_respect_annotation_metadata so field aliases and
                # serialization metadata declared on the models are honoured.
                "chat_history": convert_and_respect_annotation_metadata(
                    object_=chat_history, annotation=typing.Sequence[Message], direction="write"
                ),
                "conversation_id": conversation_id,
                "prompt_truncation": prompt_truncation,
                "connectors": convert_and_respect_annotation_metadata(
                    object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write"
                ),
                "search_queries_only": search_queries_only,
                "documents": documents,
                "citation_quality": citation_quality,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "max_input_tokens": max_input_tokens,
                "k": k,
                "p": p,
                "seed": seed,
                "stop_sequences": stop_sequences,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "raw_prompting": raw_prompting,
                "tools": convert_and_respect_annotation_metadata(
                    object_=tools, annotation=typing.Sequence[Tool], direction="write"
                ),
                "tool_results": convert_and_respect_annotation_metadata(
                    object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write"
                ),
                "force_single_step": force_single_step,
                "response_format": convert_and_respect_annotation_metadata(
                    object_=response_format, annotation=ResponseFormat, direction="write"
                ),
                "safety_mode": safety_mode,
                # This endpoint variant always requests a streamed response.
                "stream": True,
            },
            headers={
                "content-type": "application/json",
                # NOTE(review): the wire header is spelled "Accepts" (not the
                # standard "Accept"); this follows the generated API contract —
                # confirm against the API spec before changing.
                "Accepts": str(accepts) if accepts is not None else None,
            },
            request_options=request_options,
            omit=OMIT,
        ) as _response:

            async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[StreamedChatResponse]]:
                try:
                    if 200 <= _response.status_code < 300:

                        async def _iter():
                            # Each non-empty line of the stream is parsed as one
                            # JSON event; blank (keep-alive) lines are skipped.
                            async for _text in _response.aiter_lines():
                                try:
                                    if len(_text) == 0:
                                        continue
                                    yield typing.cast(
                                        StreamedChatResponse,
                                        construct_type(
                                            type_=StreamedChatResponse,  # type: ignore
                                            object_=json.loads(_text),
                                        ),
                                    )
                                except Exception:
                                    # Malformed events are dropped rather than
                                    # aborting the whole stream (generated-SDK
                                    # convention; best-effort parsing).
                                    pass
                            return

                        return AsyncHttpResponse(response=_response, data=_iter())
                    # Error path: the streamed body must be fully read before
                    # .json()/.text can be used on it.
                    await _response.aread()
                    if _response.status_code == 400:
                        raise BadRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 401:
                        raise UnauthorizedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 403:
                        raise ForbiddenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 404:
                        raise NotFoundError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 422:
                        raise UnprocessableEntityError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.text ) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e, ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) yield await _stream() async def chat( self, *, 
message: str, accepts: typing.Optional[typing.Literal["text/event-stream"]] = None, model: typing.Optional[str] = OMIT, preamble: typing.Optional[str] = OMIT, chat_history: typing.Optional[typing.Sequence[Message]] = OMIT, conversation_id: typing.Optional[str] = OMIT, prompt_truncation: typing.Optional[ChatRequestPromptTruncation] = OMIT, connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT, search_queries_only: typing.Optional[bool] = OMIT, documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT, citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT, temperature: typing.Optional[float] = OMIT, max_tokens: typing.Optional[int] = OMIT, max_input_tokens: typing.Optional[int] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, presence_penalty: typing.Optional[float] = OMIT, raw_prompting: typing.Optional[bool] = OMIT, tools: typing.Optional[typing.Sequence[Tool]] = OMIT, tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT, force_single_step: typing.Optional[bool] = OMIT, response_format: typing.Optional[ResponseFormat] = OMIT, safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[NonStreamedChatResponse]: """ Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api). Parameters ---------- message : str Text input for the model to respond to. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments accepts : typing.Optional[typing.Literal["text/event-stream"]] Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events. 
model : typing.Optional[str] The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Compatible Deployments: Cohere Platform, Private Deployments preamble : typing.Optional[str] When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments chat_history : typing.Optional[typing.Sequence[Message]] A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments conversation_id : typing.Optional[str] An alternative to `chat_history`. Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string. 
Compatible Deployments: Cohere Platform prompt_truncation : typing.Optional[ChatRequestPromptTruncation] Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. Compatible Deployments: - AUTO: Cohere Platform Only - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments connectors : typing.Optional[typing.Sequence[ChatConnector]] Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one. When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG). Compatible Deployments: Cohere Platform search_queries_only : typing.Optional[bool] Defaults to `false`. When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. 
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments documents : typing.Optional[typing.Sequence[ChatDocument]] A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. Example: ``` [ { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." }, ] ``` Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments citation_quality : typing.Optional[ChatRequestCitationQuality] Defaults to `"enabled"`. Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments temperature : typing.Optional[float] Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. 
Randomness can be further maximized by increasing the value of the `p` parameter. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_tokens : typing.Optional[int] The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments max_input_tokens : typing.Optional[int] The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer. Input will be truncated according to the `prompt_truncation` parameter. Compatible Deployments: Cohere Platform k : typing.Optional[int] Ensures only the top `k` most likely tokens are considered for generation at each step. Defaults to `0`, min value of `0`, max value of `500`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments stop_sequences : typing.Optional[typing.Sequence[str]] A list of up to 5 strings that the model will use to stop generation. 
If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments frequency_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments raw_prompting : typing.Optional[bool] When enabled, the user's prompt will be sent to the model without any pre-processing. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tools : typing.Optional[typing.Sequence[Tool]] A list of available tools (functions) that the model may suggest invoking before producing a text response. When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments tool_results : typing.Optional[typing.Sequence[ToolResult]] A list of results from invoking tools recommended by the model in the previous chat turn. 
Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. ``` tool_results = [ { "call": { "name": , "parameters": { : } }, "outputs": [{ : }] }, ... ] ``` **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments force_single_step : typing.Optional[bool] Forces the chat to be single step. Defaults to `false`. response_format : typing.Optional[ResponseFormat] safety_mode : typing.Optional[ChatRequestSafetyMode] Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `NONE` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[NonStreamedChatResponse] """ _response = await self._client_wrapper.httpx_client.request( "v1/chat", method="POST", json={ "message": message, "model": model, "preamble": preamble, "chat_history": convert_and_respect_annotation_metadata( object_=chat_history, annotation=typing.Sequence[Message], direction="write" ), "conversation_id": conversation_id, "prompt_truncation": prompt_truncation, "connectors": convert_and_respect_annotation_metadata( object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write" ), "search_queries_only": search_queries_only, "documents": documents, "citation_quality": citation_quality, "temperature": temperature, "max_tokens": max_tokens, "max_input_tokens": max_input_tokens, "k": k, "p": p, "seed": seed, "stop_sequences": stop_sequences, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "raw_prompting": raw_prompting, "tools": convert_and_respect_annotation_metadata( object_=tools, annotation=typing.Sequence[Tool], direction="write" ), "tool_results": convert_and_respect_annotation_metadata( object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write" ), "force_single_step": force_single_step, "response_format": convert_and_respect_annotation_metadata( object_=response_format, annotation=ResponseFormat, direction="write" ), "safety_mode": safety_mode, "stream": False, }, headers={ "content-type": "application/json", "Accepts": str(accepts) if accepts is not None else None, }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( NonStreamedChatResponse, construct_type( type_=NonStreamedChatResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( 
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            # No documented error status matched; keep the parsed JSON body for the generic ApiError below.
            _response_json = _response.json()
        except JSONDecodeError:
            # Body was not valid JSON at all; surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    @contextlib.asynccontextmanager
    async def generate_stream(
        self,
        *,
        prompt: str,
        model: typing.Optional[str] = OMIT,
        num_generations: typing.Optional[int] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        truncate: typing.Optional[GenerateStreamRequestTruncate] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        preset: typing.Optional[str] = OMIT,
        end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[GenerateStreamedResponse]]]:
        """
        This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.

        Generates realistic text conditioned on a given input.

        Parameters
        ----------
        prompt : str
            The input text that serves as the starting point for generating the response.
            Note: The prompt will be pre-processed and modified before reaching the model.

        model : typing.Optional[str]
            The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
            Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

        num_generations : typing.Optional[int]
            The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

            This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.

            Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

        truncate : typing.Optional[GenerateStreamRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        temperature : typing.Optional[float]
            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
            Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens
            deterministically, such that repeated requests with the same
            seed and parameters should return the same result. However,
            determinism cannot be totally guaranteed.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        preset : typing.Optional[str]
            Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
            When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

        end_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step.
            If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        frequency_penalty : typing.Optional[float]
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods]
            One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.

            If `GENERATION` is selected, the token likelihoods will only be provided for generated text.

            WARNING: `ALL` is deprecated, and will be removed in a future release.

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[GenerateStreamedResponse]]]

        """
        async with self._client_wrapper.httpx_client.stream(
            "v1/generate",
            method="POST",
            json={
                "prompt": prompt,
                "model": model,
                "num_generations": num_generations,
                "max_tokens": max_tokens,
                "truncate": truncate,
                "temperature": temperature,
                "seed": seed,
                "preset": preset,
                "end_sequences": end_sequences,
                "stop_sequences": stop_sequences,
                "k": k,
                "p": p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "return_likelihoods": return_likelihoods,
                "raw_prompting": raw_prompting,
                "stream": True,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        ) as _response:

            # Inspect the status once headers arrive: on 2xx wrap the line iterator,
            # otherwise read the body and raise a typed error.
            async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[GenerateStreamedResponse]]:
                try:
                    if 200 <= _response.status_code < 300:

                        async def _iter():
                            async for _text in _response.aiter_lines():
                                try:
                                    # Blank lines carry no event payload; skip them.
                                    if len(_text) == 0:
                                        continue
                                    yield typing.cast(
                                        GenerateStreamedResponse,
                                        construct_type(
                                            type_=GenerateStreamedResponse,  # type: ignore
                                            object_=json.loads(_text),
                                        ),
                                    )
                                # A line that fails to parse is dropped rather than aborting the stream.
                                except Exception:
                                    pass
                            return

                        return AsyncHttpResponse(response=_response, data=_iter())
                    # Error path: buffer the streamed body so .json()/.text work in the handlers below.
                    await _response.aread()
                    if _response.status_code == 400:
                        raise BadRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 401:
                        raise UnauthorizedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 403:
                        raise ForbiddenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 404:
                        raise NotFoundError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 422:
                        raise UnprocessableEntityError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 429:
                        raise TooManyRequestsError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 498:
                        raise InvalidTokenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 499:
                        raise ClientClosedRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 500:
                        raise InternalServerError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 501:
                        raise NotImplementedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 503:
                        raise ServiceUnavailableError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 504:
                        raise GatewayTimeoutError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    _response_json = _response.json()
                except JSONDecodeError:
                    raise ApiError(
                        status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
                    )
                except ValidationError as e:
                    raise ParsingError(
                        status_code=_response.status_code,
                        headers=dict(_response.headers),
                        body=_response.json(),
                        cause=e,
                    )
                raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

            yield await _stream()

    async def generate(
        self,
        *,
        prompt: str,
        model: typing.Optional[str] = OMIT,
        num_generations: typing.Optional[int] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        truncate: typing.Optional[GenerateRequestTruncate] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        preset: typing.Optional[str] = OMIT,
        end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[Generation]:
        """
        This API is marked as "Legacy" and is no longer maintained.
        Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.

        Generates realistic text conditioned on a given input.

        Parameters
        ----------
        prompt : str
            The input text that serves as the starting point for generating the response.
            Note: The prompt will be pre-processed and modified before reaching the model.

        model : typing.Optional[str]
            The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
            Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

        num_generations : typing.Optional[int]
            The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

            This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.

            Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

        truncate : typing.Optional[GenerateRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        temperature : typing.Optional[float]
            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
            Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens
            deterministically, such that repeated requests with the same
            seed and parameters should return the same result. However,
            determinism cannot be totally guaranteed.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        preset : typing.Optional[str]
            Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
            When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

        end_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step.
            If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        frequency_penalty : typing.Optional[float]
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

            Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

            Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

        return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods]
            One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.

            If `GENERATION` is selected, the token likelihoods will only be provided for generated text.

            WARNING: `ALL` is deprecated, and will be removed in a future release.

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without any pre-processing.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[Generation]

        """
        _response = await self._client_wrapper.httpx_client.request(
            "v1/generate",
            method="POST",
            json={
                "prompt": prompt,
                "model": model,
                "num_generations": num_generations,
                "max_tokens": max_tokens,
                "truncate": truncate,
                "temperature": temperature,
                "seed": seed,
                "preset": preset,
                "end_sequences": end_sequences,
                "stop_sequences": stop_sequences,
                "k": k,
                "p": p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "return_likelihoods": return_likelihoods,
                "raw_prompting": raw_prompting,
                "stream": False,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    Generation,
                    construct_type(
                        type_=Generation,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            # Map each documented error status to its typed exception; anything else falls through
            # to the generic ApiError at the bottom.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                # NOTE: this NotImplementedError takes headers/body kwargs, so it is the SDK's
                # HTTP 501 error class (shadowing the builtin) — presumably imported above; verify imports.
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            # Body was not valid JSON at all; surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    async def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type:
        typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[EmbedResponse]:
        """
        This endpoint returns text and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents.

        Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

        If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

        Parameters
        ----------
        texts : typing.Optional[typing.Sequence[str]]
            An array of strings for the model to embed. Maximum number of texts per call is `96`.

        images : typing.Optional[typing.Sequence[str]]
            An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

            The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.

            Images are only supported with Embed v3.0 and newer models.

        model : typing.Optional[str]
            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

        input_type : typing.Optional[EmbedInputType]

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.

        truncate : typing.Optional[EmbedRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[EmbedResponse]
            OK
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v1/embed",
            method="POST",
            json={
                "texts": texts,
                "images": images,
                "model": model,
                "input_type": input_type,
                "embedding_types": embedding_types,
                "truncate": truncate,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    EmbedResponse,
                    construct_type(
                        type_=EmbedResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            # Map each documented error status to its typed exception; anything else falls through
            # to the generic ApiError at the bottom.
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            # Body was not valid JSON at all; surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    async def rerank(
        self,
        *,
        query: str,
        documents: typing.Sequence[RerankRequestDocumentsItem],
        model: typing.Optional[str] = OMIT,
        top_n: typing.Optional[int] = OMIT,
        rank_fields: typing.Optional[typing.Sequence[str]] = OMIT,
        return_documents: typing.Optional[bool] = OMIT,
        max_chunks_per_doc: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[RerankResponse]:
        """
        This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.

        Parameters
        ----------
        query : str
            The search query

        documents : typing.Sequence[RerankRequestDocumentsItem]
            A list of document objects or strings to rerank.
            If a document is provided the text fields is required and all other fields will be preserved in the response.

            The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000.

            We recommend a maximum of 1,000 documents for optimal endpoint performance.

        model : typing.Optional[str]
            The identifier of the model to use, eg `rerank-v3.5`.

        top_n : typing.Optional[int]
            The number of most relevant documents or indices to return, defaults to the length of the documents

        rank_fields : typing.Optional[typing.Sequence[str]]
            If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially.
If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking. return_documents : typing.Optional[bool] - If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request. - If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request. max_chunks_per_doc : typing.Optional[int] The maximum number of chunks to produce internally from a document request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[RerankResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/rerank", method="POST", json={ "model": model, "query": query, "documents": convert_and_respect_annotation_metadata( object_=documents, annotation=typing.Sequence[RerankRequestDocumentsItem], direction="write" ), "top_n": top_n, "rank_fields": rank_fields, "return_documents": return_documents, "max_chunks_per_doc": max_chunks_per_doc, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( RerankResponse, construct_type( type_=RerankResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if 
_response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( 
type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def classify( self, *, inputs: typing.Sequence[str], examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT, model: typing.Optional[str] = OMIT, preset: typing.Optional[str] = OMIT, truncate: typing.Optional[ClassifyRequestTruncate] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[ClassifyResponse]: """ This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference. Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. Parameters ---------- inputs : typing.Sequence[str] A list of up to 96 texts to be classified. Each one must be a non-empty string. There is, however, no consistent, universal limit to the length a particular input can be. We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models). Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts. 
examples : typing.Optional[typing.Sequence[ClassifyExample]] An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`. Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. model : typing.Optional[str] ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model preset : typing.Optional[str] The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground). If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters. truncate : typing.Optional[ClassifyRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- AsyncHttpResponse[ClassifyResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/classify", method="POST", json={ "inputs": inputs, "examples": convert_and_respect_annotation_metadata( object_=examples, annotation=typing.Sequence[ClassifyExample], direction="write" ), "model": model, "preset": preset, "truncate": truncate, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( ClassifyResponse, construct_type( type_=ClassifyResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def summarize( self, *, text: str, length: typing.Optional[SummarizeRequestLength] = OMIT, format: typing.Optional[SummarizeRequestFormat] = OMIT, model: typing.Optional[str] = OMIT, extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT, temperature: typing.Optional[float] = OMIT, additional_command: typing.Optional[str] = OMIT, request_options: 
typing.Optional[RequestOptions] = None, ) -> AsyncHttpResponse[SummarizeResponse]: """ This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. Generates a summary in English for a given text. Parameters ---------- text : str The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English. length : typing.Optional[SummarizeRequestLength] One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text. format : typing.Optional[SummarizeRequestFormat] One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text. model : typing.Optional[str] The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better. extractiveness : typing.Optional[SummarizeRequestExtractiveness] One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text. temperature : typing.Optional[float] Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1. 
additional_command : typing.Optional[str] A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda" request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[SummarizeResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/summarize", method="POST", json={ "text": text, "length": length, "format": format, "model": model, "extractiveness": extractiveness, "temperature": temperature, "additional_command": additional_command, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( SummarizeResponse, construct_type( type_=SummarizeResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code 
== 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def tokenize( self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[TokenizeResponse]: """ This 
endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page. Parameters ---------- text : str The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters. model : str The input will be tokenized by the tokenizer that is used by this model. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[TokenizeResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/tokenize", method="POST", json={ "text": text, "model": model, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( TokenizeResponse, construct_type( type_=TokenizeResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) async def detokenize( self, *, tokens: typing.Sequence[int], model: str, request_options: 
typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[DetokenizeResponse]: """ This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page. Parameters ---------- tokens : typing.Sequence[int] The list of tokens to be detokenized. model : str An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[DetokenizeResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/detokenize", method="POST", json={ "tokens": tokens, "model": model, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( DetokenizeResponse, construct_type( type_=DetokenizeResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), 
body=_response_json) async def check_api_key( self, *, request_options: typing.Optional[RequestOptions] = None ) -> AsyncHttpResponse[CheckApiKeyResponse]: """ Checks that the api key in the Authorization header is valid and active Parameters ---------- request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- AsyncHttpResponse[CheckApiKeyResponse] OK """ _response = await self._client_wrapper.httpx_client.request( "v1/check-api-key", method="POST", request_options=request_options, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( CheckApiKeyResponse, construct_type( type_=CheckApiKeyResponse, # type: ignore object_=_response.json(), ), ) return AsyncHttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if 
_response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) ================================================ FILE: src/cohere/sagemaker_client.py ================================================ import typing from .aws_client import AwsClient, AwsClientV2 from .manually_maintained.cohere_aws.client import Client from .manually_maintained.cohere_aws.mode import Mode class SagemakerClient(AwsClient): 
sagemaker_finetuning: Client def __init__( self, *, aws_access_key: typing.Optional[str] = None, aws_secret_key: typing.Optional[str] = None, aws_session_token: typing.Optional[str] = None, aws_region: typing.Optional[str] = None, timeout: typing.Optional[float] = None, ): AwsClient.__init__( self, service="sagemaker", aws_access_key=aws_access_key, aws_secret_key=aws_secret_key, aws_session_token=aws_session_token, aws_region=aws_region, timeout=timeout, ) try: self.sagemaker_finetuning = Client(aws_region=aws_region) except Exception: pass class SagemakerClientV2(AwsClientV2): sagemaker_finetuning: Client def __init__( self, *, aws_access_key: typing.Optional[str] = None, aws_secret_key: typing.Optional[str] = None, aws_session_token: typing.Optional[str] = None, aws_region: typing.Optional[str] = None, timeout: typing.Optional[float] = None, ): AwsClientV2.__init__( self, service="sagemaker", aws_access_key=aws_access_key, aws_secret_key=aws_secret_key, aws_session_token=aws_session_token, aws_region=aws_region, timeout=timeout, ) try: self.sagemaker_finetuning = Client(aws_region=aws_region) except Exception: pass ================================================ FILE: src/cohere/types/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. 
# isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .api_meta import ApiMeta from .api_meta_api_version import ApiMetaApiVersion from .api_meta_billed_units import ApiMetaBilledUnits from .api_meta_tokens import ApiMetaTokens from .assistant_message import AssistantMessage from .assistant_message_response import AssistantMessageResponse from .assistant_message_response_content_item import ( AssistantMessageResponseContentItem, TextAssistantMessageResponseContentItem, ThinkingAssistantMessageResponseContentItem, ) from .assistant_message_v2content import AssistantMessageV2Content from .assistant_message_v2content_one_item import ( AssistantMessageV2ContentOneItem, TextAssistantMessageV2ContentOneItem, ThinkingAssistantMessageV2ContentOneItem, ) from .auth_token_type import AuthTokenType from .chat_citation import ChatCitation from .chat_citation_generation_event import ChatCitationGenerationEvent from .chat_citation_type import ChatCitationType from .chat_connector import ChatConnector from .chat_content_delta_event import ChatContentDeltaEvent from .chat_content_delta_event_delta import ChatContentDeltaEventDelta from .chat_content_delta_event_delta_message import ChatContentDeltaEventDeltaMessage from .chat_content_delta_event_delta_message_content import ChatContentDeltaEventDeltaMessageContent from .chat_content_end_event import ChatContentEndEvent from .chat_content_start_event import ChatContentStartEvent from .chat_content_start_event_delta import ChatContentStartEventDelta from .chat_content_start_event_delta_message import ChatContentStartEventDeltaMessage from .chat_content_start_event_delta_message_content import ChatContentStartEventDeltaMessageContent from .chat_content_start_event_delta_message_content_type import ChatContentStartEventDeltaMessageContentType from .chat_data_metrics import ChatDataMetrics from .chat_debug_event import ChatDebugEvent from .chat_document import ChatDocument from 
.chat_document_source import ChatDocumentSource from .chat_finish_reason import ChatFinishReason from .chat_message import ChatMessage from .chat_message_end_event import ChatMessageEndEvent from .chat_message_end_event_delta import ChatMessageEndEventDelta from .chat_message_start_event import ChatMessageStartEvent from .chat_message_start_event_delta import ChatMessageStartEventDelta from .chat_message_start_event_delta_message import ChatMessageStartEventDeltaMessage from .chat_message_v2 import ( AssistantChatMessageV2, ChatMessageV2, SystemChatMessageV2, ToolChatMessageV2, UserChatMessageV2, ) from .chat_messages import ChatMessages from .chat_request_citation_quality import ChatRequestCitationQuality from .chat_request_prompt_truncation import ChatRequestPromptTruncation from .chat_request_safety_mode import ChatRequestSafetyMode from .chat_search_queries_generation_event import ChatSearchQueriesGenerationEvent from .chat_search_query import ChatSearchQuery from .chat_search_result import ChatSearchResult from .chat_search_result_connector import ChatSearchResultConnector from .chat_search_results_event import ChatSearchResultsEvent from .chat_stream_end_event import ChatStreamEndEvent from .chat_stream_end_event_finish_reason import ChatStreamEndEventFinishReason from .chat_stream_event import ChatStreamEvent from .chat_stream_event_type import ChatStreamEventType from .chat_stream_request_citation_quality import ChatStreamRequestCitationQuality from .chat_stream_request_prompt_truncation import ChatStreamRequestPromptTruncation from .chat_stream_request_safety_mode import ChatStreamRequestSafetyMode from .chat_stream_start_event import ChatStreamStartEvent from .chat_text_content import ChatTextContent from .chat_text_generation_event import ChatTextGenerationEvent from .chat_text_response_format import ChatTextResponseFormat from .chat_text_response_format_v2 import ChatTextResponseFormatV2 from .chat_thinking_content import ChatThinkingContent from 
.chat_tool_call_delta_event import ChatToolCallDeltaEvent from .chat_tool_call_delta_event_delta import ChatToolCallDeltaEventDelta from .chat_tool_call_delta_event_delta_message import ChatToolCallDeltaEventDeltaMessage from .chat_tool_call_delta_event_delta_message_tool_calls import ChatToolCallDeltaEventDeltaMessageToolCalls from .chat_tool_call_delta_event_delta_message_tool_calls_function import ( ChatToolCallDeltaEventDeltaMessageToolCallsFunction, ) from .chat_tool_call_end_event import ChatToolCallEndEvent from .chat_tool_call_start_event import ChatToolCallStartEvent from .chat_tool_call_start_event_delta import ChatToolCallStartEventDelta from .chat_tool_call_start_event_delta_message import ChatToolCallStartEventDeltaMessage from .chat_tool_calls_chunk_event import ChatToolCallsChunkEvent from .chat_tool_calls_generation_event import ChatToolCallsGenerationEvent from .chat_tool_message import ChatToolMessage from .chat_tool_plan_delta_event import ChatToolPlanDeltaEvent from .chat_tool_plan_delta_event_delta import ChatToolPlanDeltaEventDelta from .chat_tool_plan_delta_event_delta_message import ChatToolPlanDeltaEventDeltaMessage from .chat_tool_source import ChatToolSource from .check_api_key_response import CheckApiKeyResponse from .citation import Citation from .citation_end_event import CitationEndEvent from .citation_options import CitationOptions from .citation_options_mode import CitationOptionsMode from .citation_start_event import CitationStartEvent from .citation_start_event_delta import CitationStartEventDelta from .citation_start_event_delta_message import CitationStartEventDeltaMessage from .citation_type import CitationType from .classify_data_metrics import ClassifyDataMetrics from .classify_example import ClassifyExample from .classify_request_truncate import ClassifyRequestTruncate from .classify_response import ClassifyResponse from .classify_response_classifications_item import ClassifyResponseClassificationsItem from 
.classify_response_classifications_item_classification_type import ( ClassifyResponseClassificationsItemClassificationType, ) from .classify_response_classifications_item_labels_value import ClassifyResponseClassificationsItemLabelsValue from .compatible_endpoint import CompatibleEndpoint from .connector import Connector from .connector_auth_status import ConnectorAuthStatus from .connector_o_auth import ConnectorOAuth from .content import Content, ImageUrlContent, TextContent from .create_connector_o_auth import CreateConnectorOAuth from .create_connector_response import CreateConnectorResponse from .create_connector_service_auth import CreateConnectorServiceAuth from .create_embed_job_response import CreateEmbedJobResponse from .dataset import Dataset from .dataset_part import DatasetPart from .dataset_type import DatasetType from .dataset_validation_status import DatasetValidationStatus from .delete_connector_response import DeleteConnectorResponse from .detokenize_response import DetokenizeResponse from .document import Document from .document_content import DocumentContent from .embed_by_type_response import EmbedByTypeResponse from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings from .embed_by_type_response_response_type import EmbedByTypeResponseResponseType from .embed_content import EmbedContent, ImageUrlEmbedContent, TextEmbedContent from .embed_floats_response import EmbedFloatsResponse from .embed_image import EmbedImage from .embed_image_url import EmbedImageUrl from .embed_input import EmbedInput from .embed_input_type import EmbedInputType from .embed_job import EmbedJob from .embed_job_status import EmbedJobStatus from .embed_job_truncate import EmbedJobTruncate from .embed_request_truncate import EmbedRequestTruncate from .embed_response import EmbedResponse, EmbeddingsByTypeEmbedResponse, EmbeddingsFloatsEmbedResponse from .embed_text import EmbedText from .embedding_type import EmbeddingType from .finetune_dataset_metrics 
import FinetuneDatasetMetrics from .finish_reason import FinishReason from .generate_request_return_likelihoods import GenerateRequestReturnLikelihoods from .generate_request_truncate import GenerateRequestTruncate from .generate_stream_end import GenerateStreamEnd from .generate_stream_end_response import GenerateStreamEndResponse from .generate_stream_error import GenerateStreamError from .generate_stream_event import GenerateStreamEvent from .generate_stream_request_return_likelihoods import GenerateStreamRequestReturnLikelihoods from .generate_stream_request_truncate import GenerateStreamRequestTruncate from .generate_stream_text import GenerateStreamText from .generate_streamed_response import ( GenerateStreamedResponse, StreamEndGenerateStreamedResponse, StreamErrorGenerateStreamedResponse, TextGenerationGenerateStreamedResponse, ) from .generation import Generation from .get_connector_response import GetConnectorResponse from .get_model_response import GetModelResponse from .get_model_response_sampling_defaults import GetModelResponseSamplingDefaults from .image import Image from .image_content import ImageContent from .image_url import ImageUrl from .image_url_detail import ImageUrlDetail from .json_response_format import JsonResponseFormat from .json_response_format_v2 import JsonResponseFormatV2 from .label_metric import LabelMetric from .list_connectors_response import ListConnectorsResponse from .list_embed_job_response import ListEmbedJobResponse from .list_models_response import ListModelsResponse from .logprob_item import LogprobItem from .message import ChatbotMessage, Message, SystemMessage, ToolMessage, UserMessage from .metrics import Metrics from .non_streamed_chat_response import NonStreamedChatResponse from .o_auth_authorize_response import OAuthAuthorizeResponse from .parse_info import ParseInfo from .rerank_document import RerankDocument from .rerank_request_documents_item import RerankRequestDocumentsItem from .rerank_response import 
RerankResponse from .rerank_response_results_item import RerankResponseResultsItem from .rerank_response_results_item_document import RerankResponseResultsItemDocument from .reranker_data_metrics import RerankerDataMetrics from .response_format import JsonObjectResponseFormat, ResponseFormat, TextResponseFormat from .response_format_v2 import JsonObjectResponseFormatV2, ResponseFormatV2, TextResponseFormatV2 from .single_generation import SingleGeneration from .single_generation_in_stream import SingleGenerationInStream from .single_generation_token_likelihoods_item import SingleGenerationTokenLikelihoodsItem from .source import DocumentSource, Source, ToolSource from .streamed_chat_response import ( CitationGenerationStreamedChatResponse, DebugStreamedChatResponse, SearchQueriesGenerationStreamedChatResponse, SearchResultsStreamedChatResponse, StreamEndStreamedChatResponse, StreamStartStreamedChatResponse, StreamedChatResponse, TextGenerationStreamedChatResponse, ToolCallsChunkStreamedChatResponse, ToolCallsGenerationStreamedChatResponse, ) from .summarize_request_extractiveness import SummarizeRequestExtractiveness from .summarize_request_format import SummarizeRequestFormat from .summarize_request_length import SummarizeRequestLength from .summarize_response import SummarizeResponse from .system_message_v2 import SystemMessageV2 from .system_message_v2content import SystemMessageV2Content from .system_message_v2content_one_item import SystemMessageV2ContentOneItem, TextSystemMessageV2ContentOneItem from .thinking import Thinking from .thinking_type import ThinkingType from .tokenize_response import TokenizeResponse from .tool import Tool from .tool_call import ToolCall from .tool_call_delta import ToolCallDelta from .tool_call_v2 import ToolCallV2 from .tool_call_v2function import ToolCallV2Function from .tool_content import DocumentToolContent, TextToolContent, ToolContent from .tool_message_v2 import ToolMessageV2 from .tool_message_v2content import 
ToolMessageV2Content from .tool_parameter_definitions_value import ToolParameterDefinitionsValue from .tool_result import ToolResult from .tool_v2 import ToolV2 from .tool_v2function import ToolV2Function from .update_connector_response import UpdateConnectorResponse from .usage import Usage from .usage_billed_units import UsageBilledUnits from .usage_tokens import UsageTokens from .user_message_v2 import UserMessageV2 from .user_message_v2content import UserMessageV2Content _dynamic_imports: typing.Dict[str, str] = { "ApiMeta": ".api_meta", "ApiMetaApiVersion": ".api_meta_api_version", "ApiMetaBilledUnits": ".api_meta_billed_units", "ApiMetaTokens": ".api_meta_tokens", "AssistantChatMessageV2": ".chat_message_v2", "AssistantMessage": ".assistant_message", "AssistantMessageResponse": ".assistant_message_response", "AssistantMessageResponseContentItem": ".assistant_message_response_content_item", "AssistantMessageV2Content": ".assistant_message_v2content", "AssistantMessageV2ContentOneItem": ".assistant_message_v2content_one_item", "AuthTokenType": ".auth_token_type", "ChatCitation": ".chat_citation", "ChatCitationGenerationEvent": ".chat_citation_generation_event", "ChatCitationType": ".chat_citation_type", "ChatConnector": ".chat_connector", "ChatContentDeltaEvent": ".chat_content_delta_event", "ChatContentDeltaEventDelta": ".chat_content_delta_event_delta", "ChatContentDeltaEventDeltaMessage": ".chat_content_delta_event_delta_message", "ChatContentDeltaEventDeltaMessageContent": ".chat_content_delta_event_delta_message_content", "ChatContentEndEvent": ".chat_content_end_event", "ChatContentStartEvent": ".chat_content_start_event", "ChatContentStartEventDelta": ".chat_content_start_event_delta", "ChatContentStartEventDeltaMessage": ".chat_content_start_event_delta_message", "ChatContentStartEventDeltaMessageContent": ".chat_content_start_event_delta_message_content", "ChatContentStartEventDeltaMessageContentType": 
".chat_content_start_event_delta_message_content_type", "ChatDataMetrics": ".chat_data_metrics", "ChatDebugEvent": ".chat_debug_event", "ChatDocument": ".chat_document", "ChatDocumentSource": ".chat_document_source", "ChatFinishReason": ".chat_finish_reason", "ChatMessage": ".chat_message", "ChatMessageEndEvent": ".chat_message_end_event", "ChatMessageEndEventDelta": ".chat_message_end_event_delta", "ChatMessageStartEvent": ".chat_message_start_event", "ChatMessageStartEventDelta": ".chat_message_start_event_delta", "ChatMessageStartEventDeltaMessage": ".chat_message_start_event_delta_message", "ChatMessageV2": ".chat_message_v2", "ChatMessages": ".chat_messages", "ChatRequestCitationQuality": ".chat_request_citation_quality", "ChatRequestPromptTruncation": ".chat_request_prompt_truncation", "ChatRequestSafetyMode": ".chat_request_safety_mode", "ChatSearchQueriesGenerationEvent": ".chat_search_queries_generation_event", "ChatSearchQuery": ".chat_search_query", "ChatSearchResult": ".chat_search_result", "ChatSearchResultConnector": ".chat_search_result_connector", "ChatSearchResultsEvent": ".chat_search_results_event", "ChatStreamEndEvent": ".chat_stream_end_event", "ChatStreamEndEventFinishReason": ".chat_stream_end_event_finish_reason", "ChatStreamEvent": ".chat_stream_event", "ChatStreamEventType": ".chat_stream_event_type", "ChatStreamRequestCitationQuality": ".chat_stream_request_citation_quality", "ChatStreamRequestPromptTruncation": ".chat_stream_request_prompt_truncation", "ChatStreamRequestSafetyMode": ".chat_stream_request_safety_mode", "ChatStreamStartEvent": ".chat_stream_start_event", "ChatTextContent": ".chat_text_content", "ChatTextGenerationEvent": ".chat_text_generation_event", "ChatTextResponseFormat": ".chat_text_response_format", "ChatTextResponseFormatV2": ".chat_text_response_format_v2", "ChatThinkingContent": ".chat_thinking_content", "ChatToolCallDeltaEvent": ".chat_tool_call_delta_event", "ChatToolCallDeltaEventDelta": 
".chat_tool_call_delta_event_delta", "ChatToolCallDeltaEventDeltaMessage": ".chat_tool_call_delta_event_delta_message", "ChatToolCallDeltaEventDeltaMessageToolCalls": ".chat_tool_call_delta_event_delta_message_tool_calls", "ChatToolCallDeltaEventDeltaMessageToolCallsFunction": ".chat_tool_call_delta_event_delta_message_tool_calls_function", "ChatToolCallEndEvent": ".chat_tool_call_end_event", "ChatToolCallStartEvent": ".chat_tool_call_start_event", "ChatToolCallStartEventDelta": ".chat_tool_call_start_event_delta", "ChatToolCallStartEventDeltaMessage": ".chat_tool_call_start_event_delta_message", "ChatToolCallsChunkEvent": ".chat_tool_calls_chunk_event", "ChatToolCallsGenerationEvent": ".chat_tool_calls_generation_event", "ChatToolMessage": ".chat_tool_message", "ChatToolPlanDeltaEvent": ".chat_tool_plan_delta_event", "ChatToolPlanDeltaEventDelta": ".chat_tool_plan_delta_event_delta", "ChatToolPlanDeltaEventDeltaMessage": ".chat_tool_plan_delta_event_delta_message", "ChatToolSource": ".chat_tool_source", "ChatbotMessage": ".message", "CheckApiKeyResponse": ".check_api_key_response", "Citation": ".citation", "CitationEndEvent": ".citation_end_event", "CitationGenerationStreamedChatResponse": ".streamed_chat_response", "CitationOptions": ".citation_options", "CitationOptionsMode": ".citation_options_mode", "CitationStartEvent": ".citation_start_event", "CitationStartEventDelta": ".citation_start_event_delta", "CitationStartEventDeltaMessage": ".citation_start_event_delta_message", "CitationType": ".citation_type", "ClassifyDataMetrics": ".classify_data_metrics", "ClassifyExample": ".classify_example", "ClassifyRequestTruncate": ".classify_request_truncate", "ClassifyResponse": ".classify_response", "ClassifyResponseClassificationsItem": ".classify_response_classifications_item", "ClassifyResponseClassificationsItemClassificationType": ".classify_response_classifications_item_classification_type", "ClassifyResponseClassificationsItemLabelsValue": 
".classify_response_classifications_item_labels_value", "CompatibleEndpoint": ".compatible_endpoint", "Connector": ".connector", "ConnectorAuthStatus": ".connector_auth_status", "ConnectorOAuth": ".connector_o_auth", "Content": ".content", "CreateConnectorOAuth": ".create_connector_o_auth", "CreateConnectorResponse": ".create_connector_response", "CreateConnectorServiceAuth": ".create_connector_service_auth", "CreateEmbedJobResponse": ".create_embed_job_response", "Dataset": ".dataset", "DatasetPart": ".dataset_part", "DatasetType": ".dataset_type", "DatasetValidationStatus": ".dataset_validation_status", "DebugStreamedChatResponse": ".streamed_chat_response", "DeleteConnectorResponse": ".delete_connector_response", "DetokenizeResponse": ".detokenize_response", "Document": ".document", "DocumentContent": ".document_content", "DocumentSource": ".source", "DocumentToolContent": ".tool_content", "EmbedByTypeResponse": ".embed_by_type_response", "EmbedByTypeResponseEmbeddings": ".embed_by_type_response_embeddings", "EmbedByTypeResponseResponseType": ".embed_by_type_response_response_type", "EmbedContent": ".embed_content", "EmbedFloatsResponse": ".embed_floats_response", "EmbedImage": ".embed_image", "EmbedImageUrl": ".embed_image_url", "EmbedInput": ".embed_input", "EmbedInputType": ".embed_input_type", "EmbedJob": ".embed_job", "EmbedJobStatus": ".embed_job_status", "EmbedJobTruncate": ".embed_job_truncate", "EmbedRequestTruncate": ".embed_request_truncate", "EmbedResponse": ".embed_response", "EmbedText": ".embed_text", "EmbeddingType": ".embedding_type", "EmbeddingsByTypeEmbedResponse": ".embed_response", "EmbeddingsFloatsEmbedResponse": ".embed_response", "FinetuneDatasetMetrics": ".finetune_dataset_metrics", "FinishReason": ".finish_reason", "GenerateRequestReturnLikelihoods": ".generate_request_return_likelihoods", "GenerateRequestTruncate": ".generate_request_truncate", "GenerateStreamEnd": ".generate_stream_end", "GenerateStreamEndResponse": 
".generate_stream_end_response", "GenerateStreamError": ".generate_stream_error", "GenerateStreamEvent": ".generate_stream_event", "GenerateStreamRequestReturnLikelihoods": ".generate_stream_request_return_likelihoods", "GenerateStreamRequestTruncate": ".generate_stream_request_truncate", "GenerateStreamText": ".generate_stream_text", "GenerateStreamedResponse": ".generate_streamed_response", "Generation": ".generation", "GetConnectorResponse": ".get_connector_response", "GetModelResponse": ".get_model_response", "GetModelResponseSamplingDefaults": ".get_model_response_sampling_defaults", "Image": ".image", "ImageContent": ".image_content", "ImageUrl": ".image_url", "ImageUrlContent": ".content", "ImageUrlDetail": ".image_url_detail", "ImageUrlEmbedContent": ".embed_content", "JsonObjectResponseFormat": ".response_format", "JsonObjectResponseFormatV2": ".response_format_v2", "JsonResponseFormat": ".json_response_format", "JsonResponseFormatV2": ".json_response_format_v2", "LabelMetric": ".label_metric", "ListConnectorsResponse": ".list_connectors_response", "ListEmbedJobResponse": ".list_embed_job_response", "ListModelsResponse": ".list_models_response", "LogprobItem": ".logprob_item", "Message": ".message", "Metrics": ".metrics", "NonStreamedChatResponse": ".non_streamed_chat_response", "OAuthAuthorizeResponse": ".o_auth_authorize_response", "ParseInfo": ".parse_info", "RerankDocument": ".rerank_document", "RerankRequestDocumentsItem": ".rerank_request_documents_item", "RerankResponse": ".rerank_response", "RerankResponseResultsItem": ".rerank_response_results_item", "RerankResponseResultsItemDocument": ".rerank_response_results_item_document", "RerankerDataMetrics": ".reranker_data_metrics", "ResponseFormat": ".response_format", "ResponseFormatV2": ".response_format_v2", "SearchQueriesGenerationStreamedChatResponse": ".streamed_chat_response", "SearchResultsStreamedChatResponse": ".streamed_chat_response", "SingleGeneration": ".single_generation", 
"SingleGenerationInStream": ".single_generation_in_stream", "SingleGenerationTokenLikelihoodsItem": ".single_generation_token_likelihoods_item", "Source": ".source", "StreamEndGenerateStreamedResponse": ".generate_streamed_response", "StreamEndStreamedChatResponse": ".streamed_chat_response", "StreamErrorGenerateStreamedResponse": ".generate_streamed_response", "StreamStartStreamedChatResponse": ".streamed_chat_response", "StreamedChatResponse": ".streamed_chat_response", "SummarizeRequestExtractiveness": ".summarize_request_extractiveness", "SummarizeRequestFormat": ".summarize_request_format", "SummarizeRequestLength": ".summarize_request_length", "SummarizeResponse": ".summarize_response", "SystemChatMessageV2": ".chat_message_v2", "SystemMessage": ".message", "SystemMessageV2": ".system_message_v2", "SystemMessageV2Content": ".system_message_v2content", "SystemMessageV2ContentOneItem": ".system_message_v2content_one_item", "TextAssistantMessageResponseContentItem": ".assistant_message_response_content_item", "TextAssistantMessageV2ContentOneItem": ".assistant_message_v2content_one_item", "TextContent": ".content", "TextEmbedContent": ".embed_content", "TextGenerationGenerateStreamedResponse": ".generate_streamed_response", "TextGenerationStreamedChatResponse": ".streamed_chat_response", "TextResponseFormat": ".response_format", "TextResponseFormatV2": ".response_format_v2", "TextSystemMessageV2ContentOneItem": ".system_message_v2content_one_item", "TextToolContent": ".tool_content", "Thinking": ".thinking", "ThinkingAssistantMessageResponseContentItem": ".assistant_message_response_content_item", "ThinkingAssistantMessageV2ContentOneItem": ".assistant_message_v2content_one_item", "ThinkingType": ".thinking_type", "TokenizeResponse": ".tokenize_response", "Tool": ".tool", "ToolCall": ".tool_call", "ToolCallDelta": ".tool_call_delta", "ToolCallV2": ".tool_call_v2", "ToolCallV2Function": ".tool_call_v2function", "ToolCallsChunkStreamedChatResponse": 
".streamed_chat_response", "ToolCallsGenerationStreamedChatResponse": ".streamed_chat_response", "ToolChatMessageV2": ".chat_message_v2", "ToolContent": ".tool_content", "ToolMessage": ".message", "ToolMessageV2": ".tool_message_v2", "ToolMessageV2Content": ".tool_message_v2content", "ToolParameterDefinitionsValue": ".tool_parameter_definitions_value", "ToolResult": ".tool_result", "ToolSource": ".source", "ToolV2": ".tool_v2", "ToolV2Function": ".tool_v2function", "UpdateConnectorResponse": ".update_connector_response", "Usage": ".usage", "UsageBilledUnits": ".usage_billed_units", "UsageTokens": ".usage_tokens", "UserChatMessageV2": ".chat_message_v2", "UserMessage": ".message", "UserMessageV2": ".user_message_v2", "UserMessageV2Content": ".user_message_v2content", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ "ApiMeta", "ApiMetaApiVersion", "ApiMetaBilledUnits", "ApiMetaTokens", "AssistantChatMessageV2", "AssistantMessage", "AssistantMessageResponse", "AssistantMessageResponseContentItem", "AssistantMessageV2Content", "AssistantMessageV2ContentOneItem", "AuthTokenType", "ChatCitation", "ChatCitationGenerationEvent", "ChatCitationType", "ChatConnector", "ChatContentDeltaEvent", "ChatContentDeltaEventDelta", "ChatContentDeltaEventDeltaMessage", "ChatContentDeltaEventDeltaMessageContent", "ChatContentEndEvent", "ChatContentStartEvent", 
"ChatContentStartEventDelta", "ChatContentStartEventDeltaMessage", "ChatContentStartEventDeltaMessageContent", "ChatContentStartEventDeltaMessageContentType", "ChatDataMetrics", "ChatDebugEvent", "ChatDocument", "ChatDocumentSource", "ChatFinishReason", "ChatMessage", "ChatMessageEndEvent", "ChatMessageEndEventDelta", "ChatMessageStartEvent", "ChatMessageStartEventDelta", "ChatMessageStartEventDeltaMessage", "ChatMessageV2", "ChatMessages", "ChatRequestCitationQuality", "ChatRequestPromptTruncation", "ChatRequestSafetyMode", "ChatSearchQueriesGenerationEvent", "ChatSearchQuery", "ChatSearchResult", "ChatSearchResultConnector", "ChatSearchResultsEvent", "ChatStreamEndEvent", "ChatStreamEndEventFinishReason", "ChatStreamEvent", "ChatStreamEventType", "ChatStreamRequestCitationQuality", "ChatStreamRequestPromptTruncation", "ChatStreamRequestSafetyMode", "ChatStreamStartEvent", "ChatTextContent", "ChatTextGenerationEvent", "ChatTextResponseFormat", "ChatTextResponseFormatV2", "ChatThinkingContent", "ChatToolCallDeltaEvent", "ChatToolCallDeltaEventDelta", "ChatToolCallDeltaEventDeltaMessage", "ChatToolCallDeltaEventDeltaMessageToolCalls", "ChatToolCallDeltaEventDeltaMessageToolCallsFunction", "ChatToolCallEndEvent", "ChatToolCallStartEvent", "ChatToolCallStartEventDelta", "ChatToolCallStartEventDeltaMessage", "ChatToolCallsChunkEvent", "ChatToolCallsGenerationEvent", "ChatToolMessage", "ChatToolPlanDeltaEvent", "ChatToolPlanDeltaEventDelta", "ChatToolPlanDeltaEventDeltaMessage", "ChatToolSource", "ChatbotMessage", "CheckApiKeyResponse", "Citation", "CitationEndEvent", "CitationGenerationStreamedChatResponse", "CitationOptions", "CitationOptionsMode", "CitationStartEvent", "CitationStartEventDelta", "CitationStartEventDeltaMessage", "CitationType", "ClassifyDataMetrics", "ClassifyExample", "ClassifyRequestTruncate", "ClassifyResponse", "ClassifyResponseClassificationsItem", "ClassifyResponseClassificationsItemClassificationType", 
"ClassifyResponseClassificationsItemLabelsValue", "CompatibleEndpoint", "Connector", "ConnectorAuthStatus", "ConnectorOAuth", "Content", "CreateConnectorOAuth", "CreateConnectorResponse", "CreateConnectorServiceAuth", "CreateEmbedJobResponse", "Dataset", "DatasetPart", "DatasetType", "DatasetValidationStatus", "DebugStreamedChatResponse", "DeleteConnectorResponse", "DetokenizeResponse", "Document", "DocumentContent", "DocumentSource", "DocumentToolContent", "EmbedByTypeResponse", "EmbedByTypeResponseEmbeddings", "EmbedByTypeResponseResponseType", "EmbedContent", "EmbedFloatsResponse", "EmbedImage", "EmbedImageUrl", "EmbedInput", "EmbedInputType", "EmbedJob", "EmbedJobStatus", "EmbedJobTruncate", "EmbedRequestTruncate", "EmbedResponse", "EmbedText", "EmbeddingType", "EmbeddingsByTypeEmbedResponse", "EmbeddingsFloatsEmbedResponse", "FinetuneDatasetMetrics", "FinishReason", "GenerateRequestReturnLikelihoods", "GenerateRequestTruncate", "GenerateStreamEnd", "GenerateStreamEndResponse", "GenerateStreamError", "GenerateStreamEvent", "GenerateStreamRequestReturnLikelihoods", "GenerateStreamRequestTruncate", "GenerateStreamText", "GenerateStreamedResponse", "Generation", "GetConnectorResponse", "GetModelResponse", "GetModelResponseSamplingDefaults", "Image", "ImageContent", "ImageUrl", "ImageUrlContent", "ImageUrlDetail", "ImageUrlEmbedContent", "JsonObjectResponseFormat", "JsonObjectResponseFormatV2", "JsonResponseFormat", "JsonResponseFormatV2", "LabelMetric", "ListConnectorsResponse", "ListEmbedJobResponse", "ListModelsResponse", "LogprobItem", "Message", "Metrics", "NonStreamedChatResponse", "OAuthAuthorizeResponse", "ParseInfo", "RerankDocument", "RerankRequestDocumentsItem", "RerankResponse", "RerankResponseResultsItem", "RerankResponseResultsItemDocument", "RerankerDataMetrics", "ResponseFormat", "ResponseFormatV2", "SearchQueriesGenerationStreamedChatResponse", "SearchResultsStreamedChatResponse", "SingleGeneration", "SingleGenerationInStream", 
"SingleGenerationTokenLikelihoodsItem", "Source", "StreamEndGenerateStreamedResponse", "StreamEndStreamedChatResponse", "StreamErrorGenerateStreamedResponse", "StreamStartStreamedChatResponse", "StreamedChatResponse", "SummarizeRequestExtractiveness", "SummarizeRequestFormat", "SummarizeRequestLength", "SummarizeResponse", "SystemChatMessageV2", "SystemMessage", "SystemMessageV2", "SystemMessageV2Content", "SystemMessageV2ContentOneItem", "TextAssistantMessageResponseContentItem", "TextAssistantMessageV2ContentOneItem", "TextContent", "TextEmbedContent", "TextGenerationGenerateStreamedResponse", "TextGenerationStreamedChatResponse", "TextResponseFormat", "TextResponseFormatV2", "TextSystemMessageV2ContentOneItem", "TextToolContent", "Thinking", "ThinkingAssistantMessageResponseContentItem", "ThinkingAssistantMessageV2ContentOneItem", "ThinkingType", "TokenizeResponse", "Tool", "ToolCall", "ToolCallDelta", "ToolCallV2", "ToolCallV2Function", "ToolCallsChunkStreamedChatResponse", "ToolCallsGenerationStreamedChatResponse", "ToolChatMessageV2", "ToolContent", "ToolMessage", "ToolMessageV2", "ToolMessageV2Content", "ToolParameterDefinitionsValue", "ToolResult", "ToolSource", "ToolV2", "ToolV2Function", "UpdateConnectorResponse", "Usage", "UsageBilledUnits", "UsageTokens", "UserChatMessageV2", "UserMessage", "UserMessageV2", "UserMessageV2Content", ] ================================================ FILE: src/cohere/types/api_meta.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta_api_version import ApiMetaApiVersion
from .api_meta_billed_units import ApiMetaBilledUnits
from .api_meta_tokens import ApiMetaTokens


class ApiMeta(UncheckedBaseModel):
    """Response metadata: API version, billed units, and token counts."""

    api_version: typing.Optional[ApiMetaApiVersion] = None
    billed_units: typing.Optional[ApiMetaBilledUnits] = None
    tokens: typing.Optional[ApiMetaTokens] = None
    cached_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of prompt tokens that hit the inference cache.
    """

    warnings: typing.Optional[typing.List[str]] = None

    if IS_PYDANTIC_V2:
        # Pydantic v2: keep unknown response fields instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        # Pydantic v1 equivalent of the config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/api_meta_api_version.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ApiMetaApiVersion(UncheckedBaseModel):
    """API version descriptor attached to response metadata."""

    version: str
    is_deprecated: typing.Optional[bool] = None
    is_experimental: typing.Optional[bool] = None

    if IS_PYDANTIC_V2:
        # Pydantic v2: keep unknown response fields instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        # Pydantic v1 equivalent of the config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/api_meta_billed_units.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class ApiMetaBilledUnits(UncheckedBaseModel): images: typing.Optional[float] = pydantic.Field(default=None) """ The number of billed images. """ input_tokens: typing.Optional[float] = pydantic.Field(default=None) """ The number of billed input tokens. """ image_tokens: typing.Optional[float] = pydantic.Field(default=None) """ The number of billed image tokens. """ output_tokens: typing.Optional[float] = pydantic.Field(default=None) """ The number of billed output tokens. """ search_units: typing.Optional[float] = pydantic.Field(default=None) """ The number of billed search units. """ classifications: typing.Optional[float] = pydantic.Field(default=None) """ The number of billed classifications units. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/api_meta_tokens.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class ApiMetaTokens(UncheckedBaseModel): input_tokens: typing.Optional[float] = pydantic.Field(default=None) """ The number of tokens used as input to the model. """ output_tokens: typing.Optional[float] = pydantic.Field(default=None) """ The number of tokens produced by the model. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/assistant_message.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .assistant_message_v2content import AssistantMessageV2Content from .citation import Citation from .tool_call_v2 import ToolCallV2 class AssistantMessage(UncheckedBaseModel): """ A message from the assistant role can contain text and tool call information. """ tool_calls: typing.Optional[typing.List[ToolCallV2]] = None tool_plan: typing.Optional[str] = pydantic.Field(default=None) """ A chain-of-thought style reflection and plan that the model generates when working with Tools. """ content: typing.Optional[AssistantMessageV2Content] = None citations: typing.Optional[typing.List[Citation]] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/assistant_message_response.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .assistant_message_response_content_item import AssistantMessageResponseContentItem from .citation import Citation from .tool_call_v2 import ToolCallV2 class AssistantMessageResponse(UncheckedBaseModel): """ A message from the assistant role can contain text and tool call information. """ role: typing.Literal["assistant"] = "assistant" tool_calls: typing.Optional[typing.List[ToolCallV2]] = None tool_plan: typing.Optional[str] = pydantic.Field(default=None) """ A chain-of-thought style reflection and plan that the model generates when working with Tools. """ content: typing.Optional[typing.List[AssistantMessageResponseContentItem]] = None citations: typing.Optional[typing.List[Citation]] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/assistant_message_response_content_item.py ================================================ # This file was auto-generated by Fern from our API Definition. 
from __future__ import annotations import typing import pydantic import typing_extensions from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata class TextAssistantMessageResponseContentItem(UncheckedBaseModel): type: typing.Literal["text"] = "text" text: str if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow class ThinkingAssistantMessageResponseContentItem(UncheckedBaseModel): type: typing.Literal["thinking"] = "thinking" thinking: str if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow AssistantMessageResponseContentItem = typing_extensions.Annotated[ typing.Union[TextAssistantMessageResponseContentItem, ThinkingAssistantMessageResponseContentItem], UnionMetadata(discriminant="type"), ] ================================================ FILE: src/cohere/types/assistant_message_v2content.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from .assistant_message_v2content_one_item import AssistantMessageV2ContentOneItem AssistantMessageV2Content = typing.Union[str, typing.List[AssistantMessageV2ContentOneItem]] ================================================ FILE: src/cohere/types/assistant_message_v2content_one_item.py ================================================ # This file was auto-generated by Fern from our API Definition. 
from __future__ import annotations import typing import pydantic import typing_extensions from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata class TextAssistantMessageV2ContentOneItem(UncheckedBaseModel): type: typing.Literal["text"] = "text" text: str if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow class ThinkingAssistantMessageV2ContentOneItem(UncheckedBaseModel): type: typing.Literal["thinking"] = "thinking" thinking: str if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow AssistantMessageV2ContentOneItem = typing_extensions.Annotated[ typing.Union[TextAssistantMessageV2ContentOneItem, ThinkingAssistantMessageV2ContentOneItem], UnionMetadata(discriminant="type"), ] ================================================ FILE: src/cohere/types/auth_token_type.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing AuthTokenType = typing.Union[typing.Literal["bearer", "basic", "noscheme"], typing.Any] ================================================ FILE: src/cohere/types/chat_citation.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_citation_type import ChatCitationType class ChatCitation(UncheckedBaseModel): """ A section of the generated reply which cites external knowledge. """ start: int = pydantic.Field() """ The index of text that the citation starts at, counting from zero. 
For example, a generation of `Hello, world!` with a citation on `world` would have a start value of `7`. This is because the citation starts at `w`, which is the seventh character. """ end: int = pydantic.Field() """ The index of text that the citation ends after, counting from zero. For example, a generation of `Hello, world!` with a citation on `world` would have an end value of `11`. This is because the citation ends after `d`, which is the eleventh character. """ text: str = pydantic.Field() """ The text of the citation. For example, a generation of `Hello, world!` with a citation of `world` would have a text value of `world`. """ document_ids: typing.List[str] = pydantic.Field() """ Identifiers of documents cited by this section of the generated reply. """ type: typing.Optional[ChatCitationType] = pydantic.Field(default=None) """ The type of citation which indicates what part of the response the citation is for. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_citation_generation_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_citation import ChatCitation from .chat_stream_event import ChatStreamEvent class ChatCitationGenerationEvent(ChatStreamEvent): citations: typing.List[ChatCitation] = pydantic.Field() """ Citations for the generated reply. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_citation_type.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing ChatCitationType = typing.Union[typing.Literal["TEXT_CONTENT", "PLAN"], typing.Any] ================================================ FILE: src/cohere/types/chat_connector.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class ChatConnector(UncheckedBaseModel): """ The connector used for fetching documents. """ id: str = pydantic.Field() """ The identifier of the connector. """ user_access_token: typing.Optional[str] = pydantic.Field(default=None) """ When specified, this user access token will be passed to the connector in the Authorization header instead of the Cohere generated one. """ continue_on_failure: typing.Optional[bool] = pydantic.Field(default=None) """ Defaults to `false`. When `true`, the request will continue if this connector returned an error. """ options: typing.Optional[typing.Dict[str, typing.Any]] = pydantic.Field(default=None) """ Provides the connector with different settings at request time. The key/value pairs of this object are specific to each connector. For example, the connector `web-search` supports the `site` option, which limits search results to the specified domain. 
""" if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_delta_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_content_delta_event_delta import ChatContentDeltaEventDelta from .chat_stream_event_type import ChatStreamEventType from .logprob_item import LogprobItem class ChatContentDeltaEvent(ChatStreamEventType): """ A streamed delta event which contains a delta of chat text content. """ index: typing.Optional[int] = None delta: typing.Optional[ChatContentDeltaEventDelta] = None logprobs: typing.Optional[LogprobItem] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_delta_event_delta.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_content_delta_event_delta_message import ChatContentDeltaEventDeltaMessage class ChatContentDeltaEventDelta(UncheckedBaseModel): message: typing.Optional[ChatContentDeltaEventDeltaMessage] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_delta_event_delta_message.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_content_delta_event_delta_message_content import ChatContentDeltaEventDeltaMessageContent class ChatContentDeltaEventDeltaMessage(UncheckedBaseModel): content: typing.Optional[ChatContentDeltaEventDeltaMessageContent] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_delta_event_delta_message_content.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class ChatContentDeltaEventDeltaMessageContent(UncheckedBaseModel): thinking: typing.Optional[str] = None text: typing.Optional[str] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_end_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_stream_event_type import ChatStreamEventType class ChatContentEndEvent(ChatStreamEventType): """ A streamed delta event which signifies that the content block has ended. """ index: typing.Optional[int] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_start_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_content_start_event_delta import ChatContentStartEventDelta from .chat_stream_event_type import ChatStreamEventType class ChatContentStartEvent(ChatStreamEventType): """ A streamed delta event which signifies that a new content block has started. 
""" index: typing.Optional[int] = None delta: typing.Optional[ChatContentStartEventDelta] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_start_event_delta.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_content_start_event_delta_message import ChatContentStartEventDeltaMessage class ChatContentStartEventDelta(UncheckedBaseModel): message: typing.Optional[ChatContentStartEventDeltaMessage] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_start_event_delta_message.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_content_start_event_delta_message_content import ChatContentStartEventDeltaMessageContent class ChatContentStartEventDeltaMessage(UncheckedBaseModel): content: typing.Optional[ChatContentStartEventDeltaMessageContent] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_start_event_delta_message_content.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_content_start_event_delta_message_content_type import ChatContentStartEventDeltaMessageContentType class ChatContentStartEventDeltaMessageContent(UncheckedBaseModel): thinking: typing.Optional[str] = None text: typing.Optional[str] = None type: typing.Optional[ChatContentStartEventDeltaMessageContentType] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_content_start_event_delta_message_content_type.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing ChatContentStartEventDeltaMessageContentType = typing.Union[typing.Literal["text", "thinking"], typing.Any] ================================================ FILE: src/cohere/types/chat_data_metrics.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class ChatDataMetrics(UncheckedBaseModel): num_train_turns: typing.Optional[int] = pydantic.Field(default=None) """ The sum of all turns of valid train examples. """ num_eval_turns: typing.Optional[int] = pydantic.Field(default=None) """ The sum of all turns of valid eval examples. """ preamble: typing.Optional[str] = pydantic.Field(default=None) """ The preamble of this dataset. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_debug_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_stream_event import ChatStreamEvent class ChatDebugEvent(ChatStreamEvent): prompt: typing.Optional[str] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_document.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing ChatDocument = typing.Dict[str, str] """ Relevant information that could be used by the model to generate a more accurate reply. The contents of each document are generally short (under 300 words), and are passed in the form of a dictionary of strings. Some suggested keys are "text", "author", "date". Both the key name and the value will be passed to the model. """ ================================================ FILE: src/cohere/types/chat_document_source.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class ChatDocumentSource(UncheckedBaseModel): """ A document source object containing the unique identifier of the document and the document itself. """ id: typing.Optional[str] = pydantic.Field(default=None) """ The unique identifier of the document """ document: typing.Optional[typing.Dict[str, typing.Any]] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_finish_reason.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing ChatFinishReason = typing.Union[ typing.Literal["COMPLETE", "STOP_SEQUENCE", "MAX_TOKENS", "TOOL_CALL", "ERROR", "TIMEOUT"], typing.Any ] ================================================ FILE: src/cohere/types/chat_message.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .tool_call import ToolCall class ChatMessage(UncheckedBaseModel): """ Represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. """ message: str = pydantic.Field() """ Contents of the chat message. """ tool_calls: typing.Optional[typing.List[ToolCall]] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_message_end_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_message_end_event_delta import ChatMessageEndEventDelta from .chat_stream_event_type import ChatStreamEventType class ChatMessageEndEvent(ChatStreamEventType): """ A streamed event which signifies that the chat message has ended. 
""" id: typing.Optional[str] = None delta: typing.Optional[ChatMessageEndEventDelta] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_message_end_event_delta.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_finish_reason import ChatFinishReason from .usage import Usage class ChatMessageEndEventDelta(UncheckedBaseModel): error: typing.Optional[str] = pydantic.Field(default=None) """ An error message if an error occurred during the generation. """ finish_reason: typing.Optional[ChatFinishReason] = None usage: typing.Optional[Usage] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_message_start_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_message_start_event_delta import ChatMessageStartEventDelta from .chat_stream_event_type import ChatStreamEventType class ChatMessageStartEvent(ChatStreamEventType): """ A streamed event which signifies that a stream has started. """ id: typing.Optional[str] = pydantic.Field(default=None) """ Unique identifier for the generated reply. 
""" delta: typing.Optional[ChatMessageStartEventDelta] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_message_start_event_delta.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .chat_message_start_event_delta_message import ChatMessageStartEventDeltaMessage class ChatMessageStartEventDelta(UncheckedBaseModel): message: typing.Optional[ChatMessageStartEventDeltaMessage] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_message_start_event_delta_message.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class ChatMessageStartEventDeltaMessage(UncheckedBaseModel): role: typing.Optional[typing.Literal["assistant"]] = pydantic.Field(default=None) """ The role of the message. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_message_v2.py ================================================ # This file was auto-generated by Fern from our API Definition. 
from __future__ import annotations import typing import pydantic import typing_extensions from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata from .assistant_message_v2content import AssistantMessageV2Content from .citation import Citation from .system_message_v2content import SystemMessageV2Content from .tool_call_v2 import ToolCallV2 from .tool_message_v2content import ToolMessageV2Content from .user_message_v2content import UserMessageV2Content class UserChatMessageV2(UncheckedBaseModel): """ Represents a single message in the chat history from a given role. """ role: typing.Literal["user"] = "user" content: UserMessageV2Content if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow class AssistantChatMessageV2(UncheckedBaseModel): """ Represents a single message in the chat history from a given role. """ role: typing.Literal["assistant"] = "assistant" tool_calls: typing.Optional[typing.List[ToolCallV2]] = None tool_plan: typing.Optional[str] = None content: typing.Optional[AssistantMessageV2Content] = None citations: typing.Optional[typing.List[Citation]] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow class SystemChatMessageV2(UncheckedBaseModel): """ Represents a single message in the chat history from a given role. 
""" role: typing.Literal["system"] = "system" content: SystemMessageV2Content if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow class ToolChatMessageV2(UncheckedBaseModel): """ Represents a single message in the chat history from a given role. """ role: typing.Literal["tool"] = "tool" tool_call_id: str content: ToolMessageV2Content if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ChatMessageV2 = typing_extensions.Annotated[ typing.Union[UserChatMessageV2, AssistantChatMessageV2, SystemChatMessageV2, ToolChatMessageV2], UnionMetadata(discriminant="role"), ] ================================================ FILE: src/cohere/types/chat_messages.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from .chat_message_v2 import ChatMessageV2 ChatMessages = typing.List[ChatMessageV2] """ A list of chat messages in chronological order, representing a conversation between the user and the model. Messages can be from `User`, `Assistant`, `Tool` and `System` roles. Learn more about messages and roles in [the Chat API guide](https://docs.cohere.com/v2/docs/chat-api). """ ================================================ FILE: src/cohere/types/chat_request_citation_quality.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing ChatRequestCitationQuality = typing.Union[typing.Literal["ENABLED", "DISABLED", "FAST", "ACCURATE", "OFF"], typing.Any] ================================================ FILE: src/cohere/types/chat_request_prompt_truncation.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing ChatRequestPromptTruncation = typing.Union[typing.Literal["OFF", "AUTO", "AUTO_PRESERVE_ORDER"], typing.Any] ================================================ FILE: src/cohere/types/chat_request_safety_mode.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing ChatRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "NONE"], typing.Any] ================================================ FILE: src/cohere/types/chat_search_queries_generation_event.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from .chat_search_query import ChatSearchQuery from .chat_stream_event import ChatStreamEvent class ChatSearchQueriesGenerationEvent(ChatStreamEvent): search_queries: typing.List[ChatSearchQuery] = pydantic.Field() """ Generated search queries, meant to be used as part of the RAG flow. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/chat_search_query.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatSearchQuery(UncheckedBaseModel):
    """
    The generated search query. Contains the text of the query and a unique identifier for the query.
    """

    text: str = pydantic.Field()
    """
    The text of the search query.
    """

    generation_id: str = pydantic.Field()
    """
    Unique identifier for the generated search query. Useful for submitting feedback.
    """

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_search_result.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_search_query import ChatSearchQuery
from .chat_search_result_connector import ChatSearchResultConnector


class ChatSearchResult(UncheckedBaseModel):
    search_query: typing.Optional[ChatSearchQuery] = None

    connector: ChatSearchResultConnector = pydantic.Field()
    """
    The connector from which this result comes from.
    """

    document_ids: typing.List[str] = pydantic.Field()
    """
    Identifiers of documents found by this search query.
    """

    error_message: typing.Optional[str] = pydantic.Field(default=None)
    """
    An error message if the search failed.
    """

    continue_on_failure: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether a chat request should continue or not if the request to this connector fails.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_search_result_connector.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatSearchResultConnector(UncheckedBaseModel):
    """
    The connector used for fetching documents.
    """

    id: str = pydantic.Field()
    """
    The identifier of the connector.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_search_results_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_document import ChatDocument
from .chat_search_result import ChatSearchResult
from .chat_stream_event import ChatStreamEvent


class ChatSearchResultsEvent(ChatStreamEvent):
    search_results: typing.Optional[typing.List[ChatSearchResult]] = pydantic.Field(default=None)
    """
    Conducted searches and the ids of documents retrieved from each of them.
    """

    documents: typing.Optional[typing.List[ChatDocument]] = pydantic.Field(default=None)
    """
    Documents fetched from searches or provided by the user.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_stream_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_end_event_finish_reason import ChatStreamEndEventFinishReason
from .chat_stream_event import ChatStreamEvent
from .non_streamed_chat_response import NonStreamedChatResponse


class ChatStreamEndEvent(ChatStreamEvent):
    finish_reason: ChatStreamEndEventFinishReason = pydantic.Field()
    """
    - `COMPLETE` - the model sent back a finished reply
    - `ERROR_LIMIT` - the reply was cut off because the model reached the maximum number of tokens for its context length
    - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens specified by the max_tokens parameter
    - `ERROR` - something went wrong when generating the reply
    - `ERROR_TOXIC` - the model generated a reply that was deemed toxic
    """

    response: NonStreamedChatResponse = pydantic.Field()
    """
    The consolidated response from the model. Contains the generated reply and all the other information streamed back in the previous events.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_stream_end_event_finish_reason.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Union with typing.Any keeps unrecognized finish-reason values parseable.
ChatStreamEndEventFinishReason = typing.Union[
    typing.Literal["COMPLETE", "ERROR_LIMIT", "MAX_TOKENS", "ERROR", "ERROR_TOXIC"], typing.Any
]


================================================
FILE: src/cohere/types/chat_stream_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatStreamEvent(UncheckedBaseModel):
    # Base class for v1 chat stream events; concrete events subclass this and
    # add their payload fields.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_stream_event_type.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatStreamEventType(UncheckedBaseModel):
    """
    The streamed event types
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_stream_request_citation_quality.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

ChatStreamRequestCitationQuality = typing.Union[
    typing.Literal["ENABLED", "DISABLED", "FAST", "ACCURATE", "OFF"], typing.Any
]


================================================
FILE: src/cohere/types/chat_stream_request_prompt_truncation.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

ChatStreamRequestPromptTruncation = typing.Union[typing.Literal["OFF", "AUTO", "AUTO_PRESERVE_ORDER"], typing.Any]


================================================
FILE: src/cohere/types/chat_stream_request_safety_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

ChatStreamRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "NONE"], typing.Any]


================================================
FILE: src/cohere/types/chat_stream_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent


class ChatStreamStartEvent(ChatStreamEvent):
    generation_id: str = pydantic.Field()
    """
    Unique identifier for the generated reply. Useful for submitting feedback.
    """

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_text_content.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatTextContent(UncheckedBaseModel):
    """
    Text content of the message.
    """

    text: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_text_generation_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent


class ChatTextGenerationEvent(ChatStreamEvent):
    text: str = pydantic.Field()
    """
    The next batch of text generated by the model.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_text_response_format.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatTextResponseFormat(UncheckedBaseModel):
    # NOTE(review): no payload fields in the generated definition — the class
    # only carries the permissive model config; confirm against the API spec.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_text_response_format_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatTextResponseFormatV2(UncheckedBaseModel):
    # No payload fields in the generated definition; only the permissive config.
    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_thinking_content.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatThinkingContent(UncheckedBaseModel):
    """
    Thinking content of the message. This will be present when `thinking` is enabled, and will contain the models internal reasoning.
    """

    thinking: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_delta_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .chat_tool_call_delta_event_delta import ChatToolCallDeltaEventDelta


class ChatToolCallDeltaEvent(ChatStreamEventType):
    """
    A streamed event delta which signifies a delta in tool call arguments.
    """

    index: typing.Optional[int] = None
    delta: typing.Optional[ChatToolCallDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_delta_event_delta_message import ChatToolCallDeltaEventDeltaMessage


class ChatToolCallDeltaEventDelta(UncheckedBaseModel):
    message: typing.Optional[ChatToolCallDeltaEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_delta_event_delta_message_tool_calls import ChatToolCallDeltaEventDeltaMessageToolCalls


class ChatToolCallDeltaEventDeltaMessage(UncheckedBaseModel):
    tool_calls: typing.Optional[ChatToolCallDeltaEventDeltaMessageToolCalls] = None

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta_message_tool_calls.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_delta_event_delta_message_tool_calls_function import (
    ChatToolCallDeltaEventDeltaMessageToolCallsFunction,
)


class ChatToolCallDeltaEventDeltaMessageToolCalls(UncheckedBaseModel):
    function: typing.Optional[ChatToolCallDeltaEventDeltaMessageToolCallsFunction] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta_message_tool_calls_function.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatToolCallDeltaEventDeltaMessageToolCallsFunction(UncheckedBaseModel):
    # Partial JSON-arguments fragment streamed for an in-progress tool call.
    # NOTE(review): "partial fragment" is inferred from the delta-event naming —
    # confirm against the streaming API docs.
    arguments: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType


class ChatToolCallEndEvent(ChatStreamEventType):
    """
    A streamed event delta which signifies a tool call has finished streaming.
    """

    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .chat_tool_call_start_event_delta import ChatToolCallStartEventDelta


class ChatToolCallStartEvent(ChatStreamEventType):
    """
    A streamed event delta which signifies a tool call has started streaming.
    """

    index: typing.Optional[int] = None
    delta: typing.Optional[ChatToolCallStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_start_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_start_event_delta_message import ChatToolCallStartEventDeltaMessage


class ChatToolCallStartEventDelta(UncheckedBaseModel):
    message: typing.Optional[ChatToolCallStartEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_call_start_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_call_v2 import ToolCallV2


class ChatToolCallStartEventDeltaMessage(UncheckedBaseModel):
    # NOTE(review): field name is plural but the type is a single ToolCallV2 —
    # this matches the generated API definition; confirm against the API spec
    # before "fixing" it.
    tool_calls: typing.Optional[ToolCallV2] = None

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_calls_chunk_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent
from .tool_call_delta import ToolCallDelta


class ChatToolCallsChunkEvent(ChatStreamEvent):
    tool_call_delta: ToolCallDelta
    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_calls_generation_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent
from .tool_call import ToolCall


class ChatToolCallsGenerationEvent(ChatStreamEvent):
    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    The text generated related to the tool calls generated
    """

    tool_calls: typing.List[ToolCall]

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_message.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_result import ToolResult


class ChatToolMessage(UncheckedBaseModel):
    """
    Represents tool result in the chat history.
    """

    tool_results: typing.Optional[typing.List[ToolResult]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_plan_delta_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .chat_tool_plan_delta_event_delta import ChatToolPlanDeltaEventDelta


class ChatToolPlanDeltaEvent(ChatStreamEventType):
    """
    A streamed event which contains a delta of tool plan text.
    """

    delta: typing.Optional[ChatToolPlanDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_plan_delta_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_plan_delta_event_delta_message import ChatToolPlanDeltaEventDeltaMessage


class ChatToolPlanDeltaEventDelta(UncheckedBaseModel):
    message: typing.Optional[ChatToolPlanDeltaEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_plan_delta_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatToolPlanDeltaEventDeltaMessage(UncheckedBaseModel):
    tool_plan: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/chat_tool_source.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ChatToolSource(UncheckedBaseModel):
    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The unique identifier of the document
    """

    tool_output: typing.Optional[typing.Dict[str, typing.Any]] = None

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/check_api_key_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class CheckApiKeyResponse(UncheckedBaseModel):
    valid: bool
    organization_id: typing.Optional[str] = None
    owner_id: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/citation.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation_type import CitationType
from .source import Source


class Citation(UncheckedBaseModel):
    """
    Citation information containing sources and the text cited.
    """

    start: typing.Optional[int] = pydantic.Field(default=None)
    """
    Start index of the cited snippet in the original source text.
    """

    end: typing.Optional[int] = pydantic.Field(default=None)
    """
    End index of the cited snippet in the original source text.
    """

    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    Text snippet that is being cited.
    """

    sources: typing.Optional[typing.List[Source]] = None

    content_index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Index of the content block in which this citation appears.
    """

    type: typing.Optional[CitationType] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/citation_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType


class CitationEndEvent(ChatStreamEventType):
    """
    A streamed event which signifies a citation has finished streaming.
    """

    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/citation_options.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation_options_mode import CitationOptionsMode


class CitationOptions(UncheckedBaseModel):
    """
    Options for controlling citation generation.
    """

    mode: typing.Optional[CitationOptionsMode] = pydantic.Field(default=None)
    """
    Defaults to `"enabled"`.
    Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/citation_options_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

CitationOptionsMode = typing.Union[typing.Literal["ENABLED", "DISABLED", "FAST", "ACCURATE", "OFF"], typing.Any]


================================================
FILE: src/cohere/types/citation_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .citation_start_event_delta import CitationStartEventDelta


class CitationStartEvent(ChatStreamEventType):
    """
    A streamed event which signifies a citation has been created.
    """

    index: typing.Optional[int] = None
    delta: typing.Optional[CitationStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/citation_start_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation_start_event_delta_message import CitationStartEventDeltaMessage


class CitationStartEventDelta(UncheckedBaseModel):
    message: typing.Optional[CitationStartEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/citation_start_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation import Citation


class CitationStartEventDeltaMessage(UncheckedBaseModel):
    # NOTE(review): field name is plural but the type is a single Citation —
    # this matches the generated API definition; confirm before changing.
    citations: typing.Optional[Citation] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/citation_type.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

CitationType = typing.Union[typing.Literal["TEXT_CONTENT", "THINKING_CONTENT", "PLAN"], typing.Any]


================================================
FILE: src/cohere/types/classify_data_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .label_metric import LabelMetric


class ClassifyDataMetrics(UncheckedBaseModel):
    label_metrics: typing.Optional[typing.List[LabelMetric]] = None

    if IS_PYDANTIC_V2:
        # extra="allow": tolerate unknown fields from newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/classify_example.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ClassifyExample(UncheckedBaseModel):
    text: typing.Optional[str] = None
    label: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/classify_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

ClassifyRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]


================================================
FILE: src/cohere/types/classify_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .classify_response_classifications_item import ClassifyResponseClassificationsItem


# Top-level classify response: one classifications entry per input, plus API metadata.
class ClassifyResponse(UncheckedBaseModel):
    id: str
    classifications: typing.List[ClassifyResponseClassificationsItem]
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Pydantic v2: retain unknown fields instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/classify_response_classifications_item.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .classify_response_classifications_item_classification_type import (
    ClassifyResponseClassificationsItemClassificationType,
)
from .classify_response_classifications_item_labels_value import ClassifyResponseClassificationsItemLabelsValue


# Classification result for a single input; the bare-string docstrings below are
# Fern-generated field descriptions and are preserved verbatim.
class ClassifyResponseClassificationsItem(UncheckedBaseModel):
    id: str
    input: typing.Optional[str] = pydantic.Field(default=None)
    """
    The input text that was classified
    """

    prediction: typing.Optional[str] = pydantic.Field(default=None)
    """
    The predicted label for the associated query (only filled for single-label models)
    """

    predictions: typing.List[str] = pydantic.Field()
    """
    An array containing the predicted labels for the associated query (only filled for single-label classification)
    """

    confidence: typing.Optional[float] = pydantic.Field(default=None)
    """
    The confidence score for the top predicted class (only filled for single-label classification)
    """

    confidences: typing.List[float] = pydantic.Field()
    """
    An array containing the confidence scores of all the predictions in the same order
    """

    labels: typing.Dict[str, ClassifyResponseClassificationsItemLabelsValue] = pydantic.Field()
    """
    A map containing each label and its confidence score according to the classifier. All the confidence scores add up to 1 for single-label classification. For multi-label classification the label confidences are independent of each other, so they don't have to sum up to 1.
    """

    classification_type: ClassifyResponseClassificationsItemClassificationType = pydantic.Field()
    """
    The type of classification performed
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/classify_response_classifications_item_classification_type.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Open enum distinguishing single- vs multi-label classification results.
ClassifyResponseClassificationsItemClassificationType = typing.Union[
    typing.Literal["single-label", "multi-label"], typing.Any
]

================================================ FILE: src/cohere/types/classify_response_classifications_item_labels_value.py ================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


# Confidence score attached to one label in a classification's `labels` map.
class ClassifyResponseClassificationsItemLabelsValue(UncheckedBaseModel):
    confidence: typing.Optional[float] = None

    if IS_PYDANTIC_V2:
        # Pydantic v2: retain unknown fields instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/compatible_endpoint.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Open enum of API endpoints a model is compatible with.
CompatibleEndpoint = typing.Union[
    typing.Literal["chat", "embed", "classify", "summarize", "rerank", "rate", "generate"], typing.Any
]

================================================ FILE: src/cohere/types/connector.py ================================================
# This file was auto-generated by Fern from our API Definition.

import datetime as dt
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector_auth_status import ConnectorAuthStatus
from .connector_o_auth import ConnectorOAuth


class Connector(UncheckedBaseModel):
    """
    A connector allows you to integrate data sources with the '/chat' endpoint to create grounded generations, with citations to the data source documents, to help answer users.
    """

    id: str = pydantic.Field()
    """
    The unique identifier of the connector (used in both `/connectors` & `/chat` endpoints).
    This is automatically created from the name of the connector upon registration.
    """

    organization_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The organization to which this connector belongs. This is automatically set to
    the organization of the user who created the connector.
    """

    name: str = pydantic.Field()
    """
    A human-readable name for the connector.
    """

    description: typing.Optional[str] = pydantic.Field(default=None)
    """
    A description of the connector.
    """

    url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The URL of the connector that will be used to search for documents.
    """

    created_at: dt.datetime = pydantic.Field()
    """
    The UTC time at which the connector was created.
    """

    updated_at: dt.datetime = pydantic.Field()
    """
    The UTC time at which the connector was last updated.
    """

    excludes: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    A list of fields to exclude from the prompt (fields remain in the document).
    """

    auth_type: typing.Optional[str] = pydantic.Field(default=None)
    """
    The type of authentication/authorization used by the connector. Possible values: [oauth, service_auth]
    """

    oauth: typing.Optional[ConnectorOAuth] = pydantic.Field(default=None)
    """
    The OAuth 2.0 configuration for the connector.
    """

    auth_status: typing.Optional[ConnectorAuthStatus] = pydantic.Field(default=None)
    """
    The OAuth status for the user making the request. One of ["valid", "expired", ""]. Empty string (field is omitted) means the user has not authorized the connector yet.
    """

    active: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether the connector is active or not.
    """

    continue_on_failure: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether a chat request should continue or not if the request to this connector fails.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/connector_auth_status.py ================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Open enum of OAuth authorization states for a connector.
ConnectorAuthStatus = typing.Union[typing.Literal["valid", "expired"], typing.Any]

================================================ FILE: src/cohere/types/connector_o_auth.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


# OAuth 2.0 configuration as returned by the API (authorize/token URLs are required here,
# unlike the optional fields on CreateConnectorOAuth).
class ConnectorOAuth(UncheckedBaseModel):
    client_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client ID. This field is encrypted at rest.
    """

    client_secret: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response.
    """

    authorize_url: str = pydantic.Field()
    """
    The OAuth 2.0 /authorize endpoint to use when users authorize the connector.
    """

    token_url: str = pydantic.Field()
    """
    The OAuth 2.0 /token endpoint to use when users authorize the connector.
    """

    scope: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth scopes to request when users authorize the connector.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/content.py ================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata


from .image_url import ImageUrl


class TextContent(UncheckedBaseModel):
    """
    A Content block which contains information about the content type and the content itself.
    """

    # Literal default doubles as the union discriminant value.
    type: typing.Literal["text"] = "text"
    text: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ImageUrlContent(UncheckedBaseModel):
    """
    A Content block which contains information about the content type and the content itself.
    """

    type: typing.Literal["image_url"] = "image_url"
    image_url: ImageUrl

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union dispatched on the "type" field.
Content = typing_extensions.Annotated[typing.Union[TextContent, ImageUrlContent], UnionMetadata(discriminant="type")]

================================================ FILE: src/cohere/types/create_connector_o_auth.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


# OAuth 2.0 configuration supplied when creating a connector; all fields optional.
class CreateConnectorOAuth(UncheckedBaseModel):
    client_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client ID. This field is encrypted at rest.
    """

    client_secret: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response.
    """

    authorize_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 /authorize endpoint to use when users authorize the connector.
    """

    token_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 /token endpoint to use when users authorize the connector.
    """

    scope: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth scopes to request when users authorize the connector.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/create_connector_response.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector


# Wrapper around the newly-created Connector returned by the create endpoint.
class CreateConnectorResponse(UncheckedBaseModel):
    connector: Connector

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/create_connector_service_auth.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .auth_token_type import AuthTokenType


# Service-auth credentials for a connector (token sent in the Authorization header).
class CreateConnectorServiceAuth(UncheckedBaseModel):
    type: AuthTokenType
    token: str = pydantic.Field()
    """
    The token that will be used in the HTTP Authorization header when making requests to the connector. This field is encrypted at rest and never returned in a response.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/create_embed_job_response.py ================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta


class CreateEmbedJobResponse(UncheckedBaseModel):
    """
    Response from creating an embed job.
    """

    job_id: str
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Pydantic v2: retain unknown fields instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/dataset.py ================================================
# This file was auto-generated by Fern from our API Definition.

import datetime as dt
import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel
from .dataset_part import DatasetPart
from .dataset_type import DatasetType
from .dataset_validation_status import DatasetValidationStatus


class Dataset(UncheckedBaseModel):
    id: str = pydantic.Field()
    """
    The dataset ID
    """

    name: str = pydantic.Field()
    """
    The name of the dataset
    """

    created_at: dt.datetime = pydantic.Field()
    """
    The creation date
    """

    updated_at: dt.datetime = pydantic.Field()
    """
    The last update date
    """

    dataset_type: DatasetType
    validation_status: DatasetValidationStatus
    validation_error: typing.Optional[str] = pydantic.Field(default=None)
    """
    Errors found during validation
    """

    # Trailing underscore avoids shadowing BaseModel internals; (de)serialized as "schema".
    schema_: typing_extensions.Annotated[
        typing.Optional[str],
        FieldMetadata(alias="schema"),
        pydantic.Field(alias="schema", description="the avro schema of the dataset"),
    ] = None
    required_fields: typing.Optional[typing.List[str]] = None
    preserve_fields: typing.Optional[typing.List[str]] = None
    dataset_parts: typing.Optional[typing.List[DatasetPart]] = pydantic.Field(default=None)
    """
    the underlying files that make up the dataset
    """

    validation_warnings: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    warnings found during validation
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/dataset_part.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


# One underlying file of a Dataset (see Dataset.dataset_parts).
class DatasetPart(UncheckedBaseModel):
    id: str = pydantic.Field()
    """
    The dataset part ID
    """

    name: str = pydantic.Field()
    """
    The name of the dataset part
    """

    url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The download url of the file
    """

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    The index of the file
    """

    size_bytes: typing.Optional[int] = pydantic.Field(default=None)
    """
    The size of the file in bytes
    """

    num_rows: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of rows in the file
    """

    original_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The download url of the original file
    """

    samples: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    The first few rows of the parsed file
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/dataset_type.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Open enum of dataset kinds accepted/produced by the datasets API.
DatasetType = typing.Union[
    typing.Literal[
        "embed-input",
        "embed-result",
        "cluster-result",
        "cluster-outliers",
        "reranker-finetune-input",
        "single-label-classification-finetune-input",
        "chat-finetune-input",
        "multi-label-classification-finetune-input",
        "batch-chat-input",
        "batch-openai-chat-input",
        "batch-embed-v2-input",
        "batch-chat-v2-input",
    ],
    typing.Any,
]

================================================ FILE: src/cohere/types/dataset_validation_status.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Open enum of dataset validation lifecycle states.
DatasetValidationStatus = typing.Union[
    typing.Literal["unknown", "queued", "processing", "failed", "validated", "skipped"], typing.Any
]

================================================ FILE: src/cohere/types/delete_connector_response.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Untyped JSON object returned when a connector is deleted.
DeleteConnectorResponse = typing.Dict[str, typing.Any]

================================================ FILE: src/cohere/types/detokenize_response.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta


class DetokenizeResponse(UncheckedBaseModel):
    text: str = pydantic.Field()
    """
    A string representing the list of tokens.
    """

    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/document.py ================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class Document(UncheckedBaseModel):
    """
    Relevant information that could be used by the model to generate a more accurate reply.
    The content of each document are generally short (should be under 300 words). Metadata should be used to provide additional information, both the key name and the value will be passed to the model.
    """

    data: typing.Dict[str, typing.Any] = pydantic.Field()
    """
    A relevant document that the model can cite to generate a more accurate reply. Each document is a string-any dictionary.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Unique identifier for this document which will be referenced in citations. If not provided an ID will be automatically generated.
    """

    if IS_PYDANTIC_V2:
        # Pydantic v2: retain unknown fields instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/document_content.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .document import Document


class DocumentContent(UncheckedBaseModel):
    """
    Document content.
    """

    document: Document

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/embed_by_type_response.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings
from .embed_by_type_response_response_type import EmbedByTypeResponseResponseType
from .image import Image


class EmbedByTypeResponse(UncheckedBaseModel):
    response_type: typing.Optional[EmbedByTypeResponseResponseType] = None
    id: str
    embeddings: EmbedByTypeResponseEmbeddings = pydantic.Field()
    """
    An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array.
    """

    texts: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    The text entries for which embeddings were returned.
    """

    images: typing.Optional[typing.List[Image]] = pydantic.Field(default=None)
    """
    The image entries for which embeddings were returned.
    """

    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/embed_by_type_response_embeddings.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel


class EmbedByTypeResponseEmbeddings(UncheckedBaseModel):
    """
    An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array.
    """

    # Trailing underscore avoids shadowing the builtin `float`; (de)serialized as "float".
    float_: typing_extensions.Annotated[
        typing.Optional[typing.List[typing.List[float]]],
        FieldMetadata(alias="float"),
        pydantic.Field(alias="float", description="An array of float embeddings."),
    ] = None
    int8: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of signed int8 embeddings. Each value is between -128 and 127.
    """

    uint8: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of unsigned int8 embeddings. Each value is between 0 and 255.
    """

    binary: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of packed signed binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between -128 and 127.
    """

    ubinary: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of packed unsigned binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between 0 and 255.
    """

    base64: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    An array of base64 embeddings. Each string is the result of appending the float embedding bytes together and base64 encoding that.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/embed_by_type_response_response_type.py ================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Open enum discriminating the two embed response shapes.
EmbedByTypeResponseResponseType = typing.Union[typing.Literal["embeddings_floats", "embeddings_by_type"], typing.Any]

================================================ FILE: src/cohere/types/embed_content.py ================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .embed_image_url import EmbedImageUrl


class ImageUrlEmbedContent(UncheckedBaseModel):
    # Literal default doubles as the union discriminant value.
    type: typing.Literal["image_url"] = "image_url"
    image_url: typing.Optional[EmbedImageUrl] = None

    if IS_PYDANTIC_V2:
        # Pydantic v2: retain unknown fields instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class TextEmbedContent(UncheckedBaseModel):
    type: typing.Literal["text"] = "text"
    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union dispatched on the "type" field.
EmbedContent = typing_extensions.Annotated[
    typing.Union[ImageUrlEmbedContent, TextEmbedContent], UnionMetadata(discriminant="type")
]

================================================ FILE: src/cohere/types/embed_floats_response.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .image import Image


class EmbedFloatsResponse(UncheckedBaseModel):
    id: str
    embeddings: typing.List[typing.List[float]] = pydantic.Field()
    """
    An array of embeddings, where each embedding is an array of floats. The length of the `embeddings` array will be the same as the length of the original `texts` array.
    """

    texts: typing.List[str] = pydantic.Field()
    """
    The text entries for which embeddings were returned.
    """

    images: typing.Optional[typing.List[Image]] = pydantic.Field(default=None)
    """
    The image entries for which embeddings were returned.
    """

    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/embed_image.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .embed_image_url import EmbedImageUrl


class EmbedImage(UncheckedBaseModel):
    """
    Image content of the input. Supported with Embed v3.0 and newer models.
    """

    image_url: typing.Optional[EmbedImageUrl] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/embed_image_url.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class EmbedImageUrl(UncheckedBaseModel):
    """
    Base64 url of image.
    """

    url: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/embed_input.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .embed_content import EmbedContent


class EmbedInput(UncheckedBaseModel):
    content: typing.List[EmbedContent] = pydantic.Field()
    """
    An array of objects containing the input data for the model to embed.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================ FILE: src/cohere/types/embed_input_type.py ================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Open enum of embedding input purposes.
EmbedInputType = typing.Union[
    typing.Literal["search_document", "search_query", "classification", "clustering", "image"], typing.Any
]

================================================ FILE: src/cohere/types/embed_job.py ================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .api_meta import ApiMeta from .embed_job_status import EmbedJobStatus from .embed_job_truncate import EmbedJobTruncate class EmbedJob(UncheckedBaseModel): job_id: str = pydantic.Field() """ ID of the embed job """ name: typing.Optional[str] = pydantic.Field(default=None) """ The name of the embed job """ status: EmbedJobStatus = pydantic.Field() """ The status of the embed job """ created_at: dt.datetime = pydantic.Field() """ The creation date of the embed job """ input_dataset_id: str = pydantic.Field() """ ID of the input dataset """ output_dataset_id: typing.Optional[str] = pydantic.Field(default=None) """ ID of the resulting output dataset """ model: str = pydantic.Field() """ ID of the model used to embed """ truncate: EmbedJobTruncate = pydantic.Field() """ The truncation option used """ meta: typing.Optional[ApiMeta] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/embed_job_status.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing EmbedJobStatus = typing.Union[typing.Literal["processing", "complete", "cancelling", "cancelled", "failed"], typing.Any] ================================================ FILE: src/cohere/types/embed_job_truncate.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing

# Open enum (typing.Any tolerates values unknown to this SDK version).
EmbedJobTruncate = typing.Union[typing.Literal["START", "END"], typing.Any]


================================================
FILE: src/cohere/types/embed_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

EmbedRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]


================================================
FILE: src/cohere/types/embed_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .api_meta import ApiMeta
from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings
from .image import Image


# Embed response variant carrying plain float embeddings, one row per input.
class EmbeddingsFloatsEmbedResponse(UncheckedBaseModel):
    response_type: typing.Literal["embeddings_floats"] = "embeddings_floats"
    id: str
    embeddings: typing.List[typing.List[float]]
    texts: typing.List[str]
    images: typing.Optional[typing.List[Image]] = None
    meta: typing.Optional[ApiMeta] = None

    # Tolerate unknown fields from newer API versions (both pydantic majors).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Embed response variant keyed by the requested embedding type(s).
class EmbeddingsByTypeEmbedResponse(UncheckedBaseModel):
    response_type: typing.Literal["embeddings_by_type"] = "embeddings_by_type"
    id: str
    embeddings: EmbedByTypeResponseEmbeddings
    texts: typing.Optional[typing.List[str]] = None
    images: typing.Optional[typing.List[Image]] = None
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union: the "response_type" field selects the concrete variant.
EmbedResponse = typing_extensions.Annotated[
    typing.Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse],
    UnionMetadata(discriminant="response_type"),
]


================================================
FILE: src/cohere/types/embed_text.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class EmbedText(UncheckedBaseModel):
    """
    Text content of the input.
    """

    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/embedding_type.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

EmbeddingType = typing.Union[typing.Literal["float", "int8", "uint8", "binary", "ubinary", "base64"], typing.Any]


================================================
FILE: src/cohere/types/finetune_dataset_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


# Size/count statistics computed over a fine-tuning dataset.
class FinetuneDatasetMetrics(UncheckedBaseModel):
    trainable_token_count: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of tokens of valid examples that can be used for training.
    """

    total_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    The overall number of examples.
    """

    train_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of training examples.
    """

    train_size_bytes: typing.Optional[int] = pydantic.Field(default=None)
    """
    The size in bytes of all training examples.
    """

    eval_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    Number of evaluation examples.
    """

    eval_size_bytes: typing.Optional[int] = pydantic.Field(default=None)
    """
    The size in bytes of all eval examples.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/finish_reason.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

FinishReason = typing.Union[
    typing.Literal[
        "COMPLETE", "STOP_SEQUENCE", "ERROR", "ERROR_TOXIC", "ERROR_LIMIT", "USER_CANCEL", "MAX_TOKENS", "TIMEOUT"
    ],
    typing.Any,
]


================================================
FILE: src/cohere/types/generate_request_return_likelihoods.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

GenerateRequestReturnLikelihoods = typing.Union[typing.Literal["GENERATION", "ALL", "NONE"], typing.Any]


================================================
FILE: src/cohere/types/generate_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

GenerateRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]


================================================
FILE: src/cohere/types/generate_stream_end.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .finish_reason import FinishReason
from .generate_stream_end_response import GenerateStreamEndResponse
from .generate_stream_event import GenerateStreamEvent


# Terminal event of a generate stream; carries the aggregated final response.
class GenerateStreamEnd(GenerateStreamEvent):
    is_finished: bool
    finish_reason: typing.Optional[FinishReason] = None
    response: GenerateStreamEndResponse

    # Tolerate unknown fields from newer API versions (both pydantic majors).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/generate_stream_end_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .single_generation_in_stream import SingleGenerationInStream


# Payload embedded in the stream-end event.
class GenerateStreamEndResponse(UncheckedBaseModel):
    id: str
    prompt: typing.Optional[str] = None
    generations: typing.Optional[typing.List[SingleGenerationInStream]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/generate_stream_error.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .finish_reason import FinishReason
from .generate_stream_event import GenerateStreamEvent


# Error event emitted on a generate stream.
class GenerateStreamError(GenerateStreamEvent):
    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero.
    """

    is_finished: bool
    finish_reason: FinishReason
    err: str = pydantic.Field()
    """
    Error message
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/generate_stream_event.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


# Base class shared by all generate-stream events; defines no fields of its
# own, only the permissive extra-field configuration.
class GenerateStreamEvent(UncheckedBaseModel):
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/generate_stream_request_return_likelihoods.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

GenerateStreamRequestReturnLikelihoods = typing.Union[typing.Literal["GENERATION", "ALL", "NONE"], typing.Any]


================================================
FILE: src/cohere/types/generate_stream_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

GenerateStreamRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]


================================================
FILE: src/cohere/types/generate_stream_text.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .generate_stream_event import GenerateStreamEvent


# Incremental text event of a generate stream.
class GenerateStreamText(GenerateStreamEvent):
    text: str = pydantic.Field()
    """
    A segment of text of the generation.
    """

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero, and only when text responses are being streamed.
    """

    is_finished: bool

    # Tolerate unknown fields from newer API versions (both pydantic majors).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/generate_streamed_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .finish_reason import FinishReason
from .generate_stream_end_response import GenerateStreamEndResponse


class TextGenerationGenerateStreamedResponse(UncheckedBaseModel):
    """
    Response in content type stream when `stream` is `true` in the request parameters.
    Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse.
    """

    event_type: typing.Literal["text-generation"] = "text-generation"
    text: str
    index: typing.Optional[int] = None
    is_finished: bool

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class StreamEndGenerateStreamedResponse(UncheckedBaseModel):
    """
    Response in content type stream when `stream` is `true` in the request parameters.
    Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse.
    """

    event_type: typing.Literal["stream-end"] = "stream-end"
    is_finished: bool
    finish_reason: typing.Optional[FinishReason] = None
    response: GenerateStreamEndResponse

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class StreamErrorGenerateStreamedResponse(UncheckedBaseModel):
    """
    Response in content type stream when `stream` is `true` in the request parameters.
    Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse.
    """

    event_type: typing.Literal["stream-error"] = "stream-error"
    index: typing.Optional[int] = None
    is_finished: bool
    finish_reason: FinishReason
    err: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union: the "event_type" field selects the concrete variant.
GenerateStreamedResponse = typing_extensions.Annotated[
    typing.Union[
        TextGenerationGenerateStreamedResponse, StreamEndGenerateStreamedResponse, StreamErrorGenerateStreamedResponse
    ],
    UnionMetadata(discriminant="event_type"),
]


================================================
FILE: src/cohere/types/generation.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .single_generation import SingleGeneration


class Generation(UncheckedBaseModel):
    id: str
    prompt: typing.Optional[str] = pydantic.Field(default=None)
    """
    Prompt used for generations.
    """

    generations: typing.List[SingleGeneration] = pydantic.Field()
    """
    List of generated results
    """

    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/get_connector_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector


# Thin wrapper around a single Connector record.
class GetConnectorResponse(UncheckedBaseModel):
    connector: Connector

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/get_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .compatible_endpoint import CompatibleEndpoint
from .get_model_response_sampling_defaults import GetModelResponseSamplingDefaults


class GetModelResponse(UncheckedBaseModel):
    """
    Contains information about the model and which API endpoints it can be used with.
    """

    name: typing.Optional[str] = pydantic.Field(default=None)
    """
    Specify this name in the `model` parameter of API requests to use your chosen model.
    """

    is_deprecated: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether the model is deprecated or not.
    """

    endpoints: typing.Optional[typing.List[CompatibleEndpoint]] = pydantic.Field(default=None)
    """
    The API endpoints that the model is compatible with.
    """

    finetuned: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether the model has been fine-tuned or not.
    """

    context_length: typing.Optional[float] = pydantic.Field(default=None)
    """
    The maximum number of tokens that the model can process in a single request. Note that not all of these tokens are always available due to special tokens and preambles that Cohere has added by default.
    """

    tokenizer_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    Public URL to the tokenizer's configuration file.
    """

    default_endpoints: typing.Optional[typing.List[CompatibleEndpoint]] = pydantic.Field(default=None)
    """
    The API endpoints that the model is default to.
    """

    features: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    The features that the model supports.
    """

    sampling_defaults: typing.Optional[GetModelResponseSamplingDefaults] = pydantic.Field(default=None)
    """
    Default sampling parameters for this model when omitted from API requests.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/get_model_response_sampling_defaults.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class GetModelResponseSamplingDefaults(UncheckedBaseModel):
    """
    Default sampling parameters for this model when omitted from API requests.
    """

    temperature: typing.Optional[float] = None
    k: typing.Optional[int] = None
    p: typing.Optional[float] = None
    frequency_penalty: typing.Optional[float] = None
    presence_penalty: typing.Optional[float] = None
    max_tokens_per_doc: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/image.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class Image(UncheckedBaseModel):
    width: int = pydantic.Field()
    """
    Width of the image in pixels
    """

    height: int = pydantic.Field()
    """
    Height of the image in pixels
    """

    format: str = pydantic.Field()
    """
    Format of the image
    """

    bit_depth: int = pydantic.Field()
    """
    Bit depth of the image
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/image_content.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .image_url import ImageUrl


class ImageContent(UncheckedBaseModel):
    """
    Image content of the message.
    """

    image_url: ImageUrl

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/image_url.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .image_url_detail import ImageUrlDetail


class ImageUrl(UncheckedBaseModel):
    url: str = pydantic.Field()
    """
    URL of an image. Can be either a base64 data URI or a web URL.
    """

    detail: typing.Optional[ImageUrlDetail] = pydantic.Field(default=None)
    """
    Controls the level of detail in image processing. `"auto"` is the default and lets the system choose, `"low"` is faster but less detailed, and `"high"` preserves maximum detail. You can save tokens and speed up responses by using detail: `"low"`.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/image_url_detail.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

ImageUrlDetail = typing.Union[typing.Literal["auto", "low", "high"], typing.Any]


================================================
FILE: src/cohere/types/json_response_format.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel


class JsonResponseFormat(UncheckedBaseModel):
    # "schema" shadows a BaseModel attribute, so the Python field is named
    # "schema_" and aliased to "schema" on the wire (FieldMetadata for this
    # SDK's serializer, pydantic.Field alias for pydantic itself).
    schema_: typing_extensions.Annotated[
        typing.Optional[typing.Dict[str, typing.Any]],
        FieldMetadata(alias="schema"),
        pydantic.Field(
            alias="schema",
            description='A JSON schema object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information.\nExample (required name and age object):\n```json\n{\n "type": "object",\n "properties": {\n "name": {"type": "string"},\n "age": {"type": "integer"}\n },\n "required": ["name", "age"]\n}\n```\n\n**Note**: This field must not be specified when the `type` is set to `"text"`.',
        ),
    ] = None

    # Tolerate unknown fields from newer API versions (both pydantic majors).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/json_response_format_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class JsonResponseFormatV2(UncheckedBaseModel):
    json_schema: typing.Optional[typing.Dict[str, typing.Any]] = pydantic.Field(default=None)
    """
    A [JSON schema](https://json-schema.org/overview/what-is-jsonschema) object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information.

    Example (required name and age object):

    ```json
    {
      "type": "object",
      "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"}
      },
      "required": ["name", "age"]
    }
    ```

    **Note**: This field must not be specified when the `type` is set to `"text"`.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/label_metric.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class LabelMetric(UncheckedBaseModel):
    total_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    Total number of examples for this label
    """

    label: typing.Optional[str] = pydantic.Field(default=None)
    """
    value of the label
    """

    samples: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    samples for this label
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/list_connectors_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector


class ListConnectorsResponse(UncheckedBaseModel):
    connectors: typing.List[Connector]
    total_count: typing.Optional[float] = pydantic.Field(default=None)
    """
    Total number of connectors.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/list_embed_job_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .embed_job import EmbedJob


class ListEmbedJobResponse(UncheckedBaseModel):
    embed_jobs: typing.Optional[typing.List[EmbedJob]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/list_models_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .get_model_response import GetModelResponse


class ListModelsResponse(UncheckedBaseModel):
    models: typing.List[GetModelResponse]
    next_page_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    A token to retrieve the next page of results. Provide in the page_token parameter of the next request.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/logprob_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class LogprobItem(UncheckedBaseModel):
    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    The text chunk for which the log probabilities was calculated.
    """

    token_ids: typing.List[int] = pydantic.Field()
    """
    The token ids of each token used to construct the text chunk.
    """

    logprobs: typing.Optional[typing.List[float]] = pydantic.Field(default=None)
    """
    The log probability of each token used to construct the text chunk.
    """

    # Tolerate unknown fields from newer API versions (both pydantic majors).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/message.py
================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .tool_call import ToolCall
from .tool_result import ToolResult


# Chat-history message authored by the model.
class ChatbotMessage(UncheckedBaseModel):
    role: typing.Literal["CHATBOT"] = "CHATBOT"
    message: str
    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Chat-history message carrying system instructions.
class SystemMessage(UncheckedBaseModel):
    role: typing.Literal["SYSTEM"] = "SYSTEM"
    message: str
    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Chat-history message authored by the end user.
class UserMessage(UncheckedBaseModel):
    role: typing.Literal["USER"] = "USER"
    message: str
    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Chat-history message carrying tool execution results (no text body).
class ToolMessage(UncheckedBaseModel):
    role: typing.Literal["TOOL"] = "TOOL"
    tool_results: typing.Optional[typing.List[ToolResult]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union: the "role" field selects the concrete variant.
Message = typing_extensions.Annotated[
    typing.Union[ChatbotMessage, SystemMessage, UserMessage, ToolMessage], UnionMetadata(discriminant="role")
]


================================================
FILE: src/cohere/types/metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .finetune_dataset_metrics import FinetuneDatasetMetrics


class Metrics(UncheckedBaseModel):
    finetune_dataset_metrics: typing.Optional[FinetuneDatasetMetrics] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/non_streamed_chat_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .chat_citation import ChatCitation
from .chat_document import ChatDocument
from .chat_search_query import ChatSearchQuery
from .chat_search_result import ChatSearchResult
from .finish_reason import FinishReason
from .message import Message
from .tool_call import ToolCall


class NonStreamedChatResponse(UncheckedBaseModel):
    """The complete (non-streaming) response of the chat endpoint."""

    text: str = pydantic.Field()
    """
    Contents of the reply generated by the model.
    """

    generation_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Unique identifier for the generated reply. Useful for submitting feedback.
    """

    response_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Unique identifier for the response.
    """

    citations: typing.Optional[typing.List[ChatCitation]] = pydantic.Field(default=None)
    """
    Inline citations for the generated reply.
    """

    documents: typing.Optional[typing.List[ChatDocument]] = pydantic.Field(default=None)
    """
    Documents seen by the model when generating the reply.
    """

    is_search_required: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Denotes that a search for documents is required during the RAG flow.
    """

    search_queries: typing.Optional[typing.List[ChatSearchQuery]] = pydantic.Field(default=None)
    """
    Generated search queries, meant to be used as part of the RAG flow.
    """

    search_results: typing.Optional[typing.List[ChatSearchResult]] = pydantic.Field(default=None)
    """
    Documents retrieved from each of the conducted searches.
    """

    finish_reason: typing.Optional[FinishReason] = None
    tool_calls: typing.Optional[typing.List[ToolCall]] = None
    chat_history: typing.Optional[typing.List[Message]] = pydantic.Field(default=None)
    """
    A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`.
    """

    meta: typing.Optional[ApiMeta] = None

    # Unrecognized fields returned by the API are retained (extra="allow").
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/o_auth_authorize_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class OAuthAuthorizeResponse(UncheckedBaseModel):
    redirect_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 redirect url. Redirect the user to this url to authorize the connector.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/parse_info.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ParseInfo(UncheckedBaseModel):
    separator: typing.Optional[str] = None
    delimiter: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/rerank_document.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing RerankDocument = typing.Dict[str, str] ================================================ FILE: src/cohere/types/rerank_request_documents_item.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from .rerank_document import RerankDocument RerankRequestDocumentsItem = typing.Union[str, RerankDocument] ================================================ FILE: src/cohere/types/rerank_response.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .api_meta import ApiMeta from .rerank_response_results_item import RerankResponseResultsItem class RerankResponse(UncheckedBaseModel): id: typing.Optional[str] = None results: typing.List[RerankResponseResultsItem] = pydantic.Field() """ An ordered list of ranked documents """ meta: typing.Optional[ApiMeta] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/rerank_response_results_item.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .rerank_response_results_item_document import RerankResponseResultsItemDocument class RerankResponseResultsItem(UncheckedBaseModel): document: typing.Optional[RerankResponseResultsItemDocument] = pydantic.Field(default=None) """ If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in """ index: int = pydantic.Field() """ Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance) """ relevance_score: float = pydantic.Field() """ Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45 """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/rerank_response_results_item_document.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class RerankResponseResultsItemDocument(UncheckedBaseModel):
    """
    If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in
    """

    text: str = pydantic.Field()
    """
    The text of the document to rerank
    """

    # Unrecognized fields returned by the API are retained (extra="allow").
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/reranker_data_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class RerankerDataMetrics(UncheckedBaseModel):
    """Counts describing the training/evaluation data of a reranker finetune."""

    num_train_queries: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of training queries.
    """

    num_train_relevant_passages: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all relevant passages of valid training examples.
    """

    num_train_hard_negatives: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all hard negatives of valid training examples.
    """

    num_eval_queries: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of evaluation queries.
    """

    num_eval_relevant_passages: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all relevant passages of valid eval examples.
    """

    num_eval_hard_negatives: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all hard negatives of valid eval examples.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/response_format.py
================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata


class TextResponseFormat(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R 03-2024](https://docs.cohere.com/docs/command-r), [Command R+ 04-2024](https://docs.cohere.com/docs/command-r-plus) and newer models.

    The model can be forced into outputting JSON objects (with up to 5 levels of nesting) by setting `{ "type": "json_object" }`.

    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.

    **Note**: When using  `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.
    **Limitation**: The parameter is not supported in RAG mode (when any of `connectors`, `documents`, `tools`, `tool_results` are provided).
    """

    type: typing.Literal["text"] = "text"

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class JsonObjectResponseFormat(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R 03-2024](https://docs.cohere.com/docs/command-r), [Command R+ 04-2024](https://docs.cohere.com/docs/command-r-plus) and newer models.

    The model can be forced into outputting JSON objects (with up to 5 levels of nesting) by setting `{ "type": "json_object" }`.

    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.

    **Note**: When using  `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.
    **Limitation**: The parameter is not supported in RAG mode (when any of `connectors`, `documents`, `tools`, `tool_results` are provided).
    """

    type: typing.Literal["json_object"] = "json_object"
    # Serialized under the wire name "schema"; trailing underscore avoids
    # shadowing pydantic's BaseModel.schema.
    schema_: typing_extensions.Annotated[
        typing.Optional[typing.Dict[str, typing.Any]], FieldMetadata(alias="schema"), pydantic.Field(alias="schema")
    ] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union discriminated by the `type` field ("text" vs "json_object").
ResponseFormat = typing_extensions.Annotated[
    typing.Union[TextResponseFormat, JsonObjectResponseFormat], UnionMetadata(discriminant="type")
]


================================================
FILE: src/cohere/types/response_format_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata


class TextResponseFormatV2(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R](https://docs.cohere.com/v2/docs/command-r), [Command R+](https://docs.cohere.com/v2/docs/command-r-plus) and newer models.

    The model can be forced into outputting JSON objects by setting `{ "type": "json_object" }`.

    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.

    **Note**: When using  `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.

    **Note**: When `json_schema` is not specified, the generated object can have up to 5 layers of nesting.

    **Limitation**: The parameter is not supported when used in combinations with the `documents` or `tools` parameters.
    """

    type: typing.Literal["text"] = "text"

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class JsonObjectResponseFormatV2(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R](https://docs.cohere.com/v2/docs/command-r), [Command R+](https://docs.cohere.com/v2/docs/command-r-plus) and newer models.

    The model can be forced into outputting JSON objects by setting `{ "type": "json_object" }`.

    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.

    **Note**: When using  `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.

    **Note**: When `json_schema` is not specified, the generated object can have up to 5 layers of nesting.

    **Limitation**: The parameter is not supported when used in combinations with the `documents` or `tools` parameters.
    """

    type: typing.Literal["json_object"] = "json_object"
    json_schema: typing.Optional[typing.Dict[str, typing.Any]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union discriminated by the `type` field ("text" vs "json_object").
ResponseFormatV2 = typing_extensions.Annotated[
    typing.Union[TextResponseFormatV2, JsonObjectResponseFormatV2], UnionMetadata(discriminant="type")
]


================================================
FILE: src/cohere/types/single_generation.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .single_generation_token_likelihoods_item import SingleGenerationTokenLikelihoodsItem


class SingleGeneration(UncheckedBaseModel):
    id: str
    text: str
    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero.
    """

    likelihood: typing.Optional[float] = None
    token_likelihoods: typing.Optional[typing.List[SingleGenerationTokenLikelihoodsItem]] = pydantic.Field(default=None)
    """
    Only returned if `return_likelihoods` is set to `GENERATION` or `ALL`.

    The likelihood refers to the average log-likelihood of the entire specified string, which is useful for [evaluating the performance of your model](likelihood-eval), especially if you've created a [custom model](https://docs.cohere.com/docs/training-custom-models). Individual token likelihoods provide the log-likelihood of each token. The first token will not have a likelihood.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/single_generation_in_stream.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .finish_reason import FinishReason


class SingleGenerationInStream(UncheckedBaseModel):
    id: str
    text: str = pydantic.Field()
    """
    Full text of the generation.
    """

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero.
    """

    finish_reason: FinishReason

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/single_generation_token_likelihoods_item.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class SingleGenerationTokenLikelihoodsItem(UncheckedBaseModel):
    """A single token paired with its likelihood."""

    token: str
    likelihood: float

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/source.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations import typing import pydantic import typing_extensions from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata class ToolSource(UncheckedBaseModel): """ A source object containing information about the source of the data cited. """ type: typing.Literal["tool"] = "tool" id: typing.Optional[str] = None tool_output: typing.Optional[typing.Dict[str, typing.Any]] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow class DocumentSource(UncheckedBaseModel): """ A source object containing information about the source of the data cited. """ type: typing.Literal["document"] = "document" id: typing.Optional[str] = None document: typing.Optional[typing.Dict[str, typing.Any]] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow Source = typing_extensions.Annotated[typing.Union[ToolSource, DocumentSource], UnionMetadata(discriminant="type")] ================================================ FILE: src/cohere/types/streamed_chat_response.py ================================================ # This file was auto-generated by Fern from our API Definition. 
from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .chat_citation import ChatCitation
from .chat_document import ChatDocument
from .chat_search_query import ChatSearchQuery
from .chat_search_result import ChatSearchResult
from .chat_stream_end_event_finish_reason import ChatStreamEndEventFinishReason
from .non_streamed_chat_response import NonStreamedChatResponse
from .tool_call import ToolCall
from .tool_call_delta import ToolCallDelta

# Each class below is one event variant of the StreamedChatResponse union,
# discriminated by the literal `event_type` field.


class StreamStartStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["stream-start"] = "stream-start"
    generation_id: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class SearchQueriesGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["search-queries-generation"] = "search-queries-generation"
    search_queries: typing.List[ChatSearchQuery]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class SearchResultsStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["search-results"] = "search-results"
    search_results: typing.Optional[typing.List[ChatSearchResult]] = None
    documents: typing.Optional[typing.List[ChatDocument]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class TextGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["text-generation"] = "text-generation"
    text: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class CitationGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["citation-generation"] = "citation-generation"
    citations: typing.List[ChatCitation]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolCallsGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["tool-calls-generation"] = "tool-calls-generation"
    text: typing.Optional[str] = None
    tool_calls: typing.List[ToolCall]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class StreamEndStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["stream-end"] = "stream-end"
    finish_reason: ChatStreamEndEventFinishReason
    # The final event carries the full non-streamed response object.
    response: NonStreamedChatResponse

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolCallsChunkStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["tool-calls-chunk"] = "tool-calls-chunk"
    tool_call_delta: ToolCallDelta
    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class DebugStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    event_type: typing.Literal["debug"] = "debug"
    prompt: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union over all stream event variants, discriminated by `event_type`.
StreamedChatResponse = typing_extensions.Annotated[
    typing.Union[
        StreamStartStreamedChatResponse,
        SearchQueriesGenerationStreamedChatResponse,
        SearchResultsStreamedChatResponse,
        TextGenerationStreamedChatResponse,
        CitationGenerationStreamedChatResponse,
        ToolCallsGenerationStreamedChatResponse,
        StreamEndStreamedChatResponse,
        ToolCallsChunkStreamedChatResponse,
        DebugStreamedChatResponse,
    ],
    UnionMetadata(discriminant="event_type"),
]


================================================
FILE: src/cohere/types/summarize_request_extractiveness.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

SummarizeRequestExtractiveness = typing.Union[typing.Literal["low", "medium", "high"], typing.Any]


================================================
FILE: src/cohere/types/summarize_request_format.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

SummarizeRequestFormat = typing.Union[typing.Literal["paragraph", "bullets"], typing.Any]


================================================
FILE: src/cohere/types/summarize_request_length.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

SummarizeRequestLength = typing.Union[typing.Literal["short", "medium", "long"], typing.Any]


================================================
FILE: src/cohere/types/summarize_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .api_meta import ApiMeta class SummarizeResponse(UncheckedBaseModel): id: typing.Optional[str] = pydantic.Field(default=None) """ Generated ID for the summary """ summary: typing.Optional[str] = pydantic.Field(default=None) """ Generated summary for the text """ meta: typing.Optional[ApiMeta] = None if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/system_message_v2.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .system_message_v2content import SystemMessageV2Content class SystemMessageV2(UncheckedBaseModel): """ A message from the system. """ content: SystemMessageV2Content if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/system_message_v2content.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from .system_message_v2content_one_item import SystemMessageV2ContentOneItem SystemMessageV2Content = typing.Union[str, typing.List[SystemMessageV2ContentOneItem]] ================================================ FILE: src/cohere/types/system_message_v2content_one_item.py ================================================ # This file was auto-generated by Fern from our API Definition. 
from __future__ import annotations import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel class TextSystemMessageV2ContentOneItem(UncheckedBaseModel): type: typing.Literal["text"] = "text" text: str if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow SystemMessageV2ContentOneItem = TextSystemMessageV2ContentOneItem ================================================ FILE: src/cohere/types/thinking.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing import pydantic from ..core.pydantic_utilities import IS_PYDANTIC_V2 from ..core.unchecked_base_model import UncheckedBaseModel from .thinking_type import ThinkingType class Thinking(UncheckedBaseModel): """ Configuration for [reasoning features](https://docs.cohere.com/docs/reasoning). """ type: ThinkingType = pydantic.Field() """ Reasoning is enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`. """ token_budget: typing.Optional[int] = pydantic.Field(default=None) """ The maximum number of tokens the model can use for thinking, which must be set to a positive integer. The model will stop thinking if it reaches the thinking token budget and will proceed with the response. """ if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 else: class Config: smart_union = True extra = pydantic.Extra.allow ================================================ FILE: src/cohere/types/thinking_type.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import typing

# "enabled"/"disabled" are the known values; typing.Any keeps the union open so new
# server-side values don't break deserialization.
ThinkingType = typing.Union[typing.Literal["enabled", "disabled"], typing.Any]


================================================
FILE: src/cohere/types/tokenize_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta


class TokenizeResponse(UncheckedBaseModel):
    """Response model for the tokenize endpoint: token ids, their string forms, and metadata."""

    tokens: typing.List[int] = pydantic.Field()
    """
    An array of tokens, where each token is an integer.
    """

    # String rendering of each token; presumably parallel to `tokens` — TODO confirm against API docs.
    token_strings: typing.List[str]
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Tolerate unknown fields returned by newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_parameter_definitions_value import ToolParameterDefinitionsValue


class Tool(UncheckedBaseModel):
    """A v1 tool (function) definition the model may call."""

    name: str = pydantic.Field()
    """
    The name of the tool to be called. Valid names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit.
    """

    description: str = pydantic.Field()
    """
    The description of what the tool does, the model uses the description to choose when and how to call the function.
    """

    parameter_definitions: typing.Optional[typing.Dict[str, ToolParameterDefinitionsValue]] = pydantic.Field(
        default=None
    )
    """
    The input parameters of the tool. Accepts a dictionary where the key is the name of the parameter and the value is the parameter spec. Valid parameter names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit.
    ```
    {
        "my_param": {
            "description": <string>,
            "type": <string>, // any python data type, such as 'str', 'bool'
            "required": <boolean>
        }
    }
    ```
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_call.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ToolCall(UncheckedBaseModel):
    """
    Contains the tool calls generated by the model. Use it to invoke your tools.
    """

    name: str = pydantic.Field()
    """
    Name of the tool to call.
    """

    parameters: typing.Dict[str, typing.Any] = pydantic.Field()
    """
    The name and value of the parameters to use when invoking a tool.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_call_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ToolCallDelta(UncheckedBaseModel):
    """
    Contains the chunk of the tool call generation in the stream.
    """

    name: typing.Optional[str] = pydantic.Field(default=None)
    """
    Name of the tool call
    """

    index: typing.Optional[float] = pydantic.Field(default=None)
    """
    Index of the tool call generated
    """

    parameters: typing.Optional[str] = pydantic.Field(default=None)
    """
    Chunk of the tool parameters
    """

    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    Chunk of the tool plan text
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_call_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_call_v2function import ToolCallV2Function


class ToolCallV2(UncheckedBaseModel):
    """
    An array of tool calls to be made.
    """

    id: str
    # Discriminator tag; "function" is the only variant.
    type: typing.Literal["function"] = "function"
    function: typing.Optional[ToolCallV2Function] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_call_v2function.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ToolCallV2Function(UncheckedBaseModel):
    """The function portion of a v2 tool call: the function name and its JSON-encoded arguments."""

    name: typing.Optional[str] = None
    # JSON string of the arguments; optional because it may arrive in chunks when streaming —
    # NOTE(review): chunking assumption inferred from the streaming delta types, confirm.
    arguments: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        # Tolerate unknown fields returned by newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_content.py
================================================
# This file was auto-generated by Fern from our API Definition.

from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .document import Document


class TextToolContent(UncheckedBaseModel):
    """
    A content block which contains information about the content of a tool result
    """

    type: typing.Literal["text"] = "text"
    text: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class DocumentToolContent(UncheckedBaseModel):
    """
    A content block which contains information about the content of a tool result
    """

    type: typing.Literal["document"] = "document"
    document: Document

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union over the `type` field: "text" -> TextToolContent, "document" -> DocumentToolContent.
ToolContent = typing_extensions.Annotated[
    typing.Union[TextToolContent, DocumentToolContent], UnionMetadata(discriminant="type")
]


================================================
FILE: src/cohere/types/tool_message_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_message_v2content import ToolMessageV2Content


class ToolMessageV2(UncheckedBaseModel):
    """
    A message with Tool outputs.
    """

    tool_call_id: str = pydantic.Field()
    """
    The id of the associated tool call that has provided the given content
    """

    content: ToolMessageV2Content = pydantic.Field()
    """
    Outputs from a tool. The content should formatted as a JSON object string, or a list of tool content blocks
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_message_v2content.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

from .tool_content import ToolContent

# Tool output: either a JSON object string or a list of typed content blocks.
ToolMessageV2Content = typing.Union[str, typing.List[ToolContent]]


================================================
FILE: src/cohere/types/tool_parameter_definitions_value.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ToolParameterDefinitionsValue(UncheckedBaseModel):
    """Specification of a single tool parameter (see Tool.parameter_definitions)."""

    description: typing.Optional[str] = pydantic.Field(default=None)
    """
    The description of the parameter.
    """

    type: str = pydantic.Field()
    """
    The type of the parameter. Must be a valid Python type.
    """

    required: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Denotes whether the parameter is always present (required) or not. Defaults to not required.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_result.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_call import ToolCall


class ToolResult(UncheckedBaseModel):
    """Pairs a tool call with the list of outputs produced by running it."""

    call: ToolCall
    outputs: typing.List[typing.Dict[str, typing.Any]]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_v2function import ToolV2Function


class ToolV2(UncheckedBaseModel):
    """A v2 tool definition; currently only the "function" variant exists."""

    type: typing.Literal["function"] = "function"
    function: typing.Optional[ToolV2Function] = pydantic.Field(default=None)
    """
    The function to be executed.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/tool_v2function.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class ToolV2Function(UncheckedBaseModel):
    """
    The function to be executed.
    """

    name: str = pydantic.Field()
    """
    The name of the function.
    """

    description: typing.Optional[str] = pydantic.Field(default=None)
    """
    The description of the function.
    """

    parameters: typing.Dict[str, typing.Any] = pydantic.Field()
    """
    The parameters of the function as a JSON schema.
    """

    if IS_PYDANTIC_V2:
        # Tolerate unknown fields returned by newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/update_connector_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector


class UpdateConnectorResponse(UncheckedBaseModel):
    """Response model wrapping the connector as it exists after the update."""

    connector: Connector

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/usage.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .usage_billed_units import UsageBilledUnits
from .usage_tokens import UsageTokens


class Usage(UncheckedBaseModel):
    """Token/unit accounting for a request: billed units, raw token counts, and cache hits."""

    billed_units: typing.Optional[UsageBilledUnits] = None
    tokens: typing.Optional[UsageTokens] = None
    cached_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of prompt tokens that hit the inference cache.
    """

    if IS_PYDANTIC_V2:
        # Tolerate unknown fields returned by newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/usage_billed_units.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class UsageBilledUnits(UncheckedBaseModel):
    """Billable unit counts for a request; all fields optional since only relevant ones are returned."""

    input_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed input tokens.
    """

    output_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed output tokens.
    """

    search_units: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed search units.
    """

    classifications: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed classifications units.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/usage_tokens.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel


class UsageTokens(UncheckedBaseModel):
    """Raw (unbilled) token counts for a request."""

    input_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of tokens used as input to the model.
    """

    output_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of tokens produced by the model.
    """

    if IS_PYDANTIC_V2:
        # Tolerate unknown fields returned by newer API versions.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/user_message_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .user_message_v2content import UserMessageV2Content


class UserMessageV2(UncheckedBaseModel):
    """
    A message from the user.
    """

    content: UserMessageV2Content = pydantic.Field()
    """
    The content of the message. This can be a string or a list of content blocks. If a string is provided, it will be treated as a text content block.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/types/user_message_v2content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing from .content import Content UserMessageV2Content = typing.Union[str, typing.List[Content]] ================================================ FILE: src/cohere/utils.py ================================================ import asyncio import csv import json import time import typing from typing import Optional import requests from fastavro import parse_schema, reader, writer from . import EmbedResponse, EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse, ApiMeta, \ EmbedByTypeResponseEmbeddings, ApiMetaBilledUnits, EmbedJob, CreateEmbedJobResponse, Dataset from .datasets import DatasetsCreateResponse, DatasetsGetResponse from .overrides import get_fields # Note: utils.py does NOT call run_overrides() itself - that's done in client.py # which imports utils.py. This ensures overrides are applied when client is used. def get_terminal_states(): return get_success_states() | get_failed_states() def get_success_states(): return {"complete", "validated"} def get_failed_states(): return {"unknown", "failed", "skipped", "cancelled", "failed"} def get_id( awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]): return getattr(awaitable, "job_id", None) or getattr(awaitable, "id", None) or getattr( getattr(awaitable, "dataset", None), "id", None) def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]): return getattr(awaitable, "status", None) or getattr(getattr(awaitable, "dataset", None), "validation_status", None) def get_job(cohere: typing.Any, awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \ typing.Union[ EmbedJob, DatasetsGetResponse]: if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse": return cohere.embed_jobs.get(id=get_id(awaitable)) elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse": return 
cohere.datasets.get(id=get_id(awaitable)) else: raise ValueError(f"Unexpected awaitable type {awaitable}") async def async_get_job(cohere: typing.Any, awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse]) -> \ typing.Union[ EmbedJob, DatasetsGetResponse]: if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse": return await cohere.embed_jobs.get(id=get_id(awaitable)) elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse": return await cohere.datasets.get(id=get_id(awaitable)) else: raise ValueError(f"Unexpected awaitable type {awaitable}") def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]: if isinstance(job, EmbedJob): return f"Embed job {job.job_id} failed with status {job.status}" elif isinstance(job, DatasetsGetResponse): return f"Dataset creation failed with status {job.dataset.validation_status} and error : {job.dataset.validation_error}" return None @typing.overload def wait( cohere: typing.Any, awaitable: CreateEmbedJobResponse, timeout: Optional[float] = None, interval: float = 10, ) -> EmbedJob: ... @typing.overload def wait( cohere: typing.Any, awaitable: DatasetsCreateResponse, timeout: Optional[float] = None, interval: float = 10, ) -> DatasetsGetResponse: ... 
def wait(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
        timeout: Optional[float] = None,
        interval: float = 2,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Block until the job/dataset behind ``awaitable`` reaches a terminal state.

    Polls via :func:`get_job` every ``interval`` seconds, printing ``...`` as a
    progress heartbeat. Returns the final job object on success.

    Raises
    ------
    TimeoutError
        If ``timeout`` seconds elapse before a terminal state is reached.
    Exception
        If the job ends in a failed state (message from :func:`get_failure_reason`).
    """
    deadline = None if timeout is None else time.time() + timeout
    done_states = get_terminal_states()

    current = get_job(cohere, awaitable)
    while get_validation_status(current) not in done_states:
        if deadline is not None and time.time() > deadline:
            raise TimeoutError(f"wait timed out after {timeout} seconds")
        time.sleep(interval)
        print("...")
        current = get_job(cohere, awaitable)

    if get_validation_status(current) in get_failed_states():
        raise Exception(get_failure_reason(current))

    return current


@typing.overload
async def async_wait(
        cohere: typing.Any,
        awaitable: CreateEmbedJobResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> EmbedJob:
    ...


@typing.overload
async def async_wait(
        cohere: typing.Any,
        awaitable: DatasetsCreateResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> DatasetsGetResponse:
    ...
async def async_wait( cohere: typing.Any, awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse], timeout: Optional[float] = None, interval: float = 10, ) -> typing.Union[EmbedJob, DatasetsGetResponse]: start_time = time.time() terminal_states = get_terminal_states() failed_states = get_failed_states() job = await async_get_job(cohere, awaitable) while get_validation_status(job) not in terminal_states: if timeout is not None and time.time() - start_time > timeout: raise TimeoutError(f"wait timed out after {timeout} seconds") await asyncio.sleep(interval) print("...") job = await async_get_job(cohere, awaitable) if get_validation_status(job) in failed_states: raise Exception(get_failure_reason(job)) return job def sum_fields_if_not_none(obj: typing.Any, field: str) -> Optional[int]: non_none = [getattr(obj, field) for obj in obj if getattr(obj, field) is not None] return sum(non_none) if non_none else None def merge_meta_field(metas: typing.List[ApiMeta]) -> ApiMeta: api_version = metas[0].api_version if metas else None billed_units = [meta.billed_units for meta in metas] input_tokens = sum_fields_if_not_none(billed_units, "input_tokens") output_tokens = sum_fields_if_not_none(billed_units, "output_tokens") search_units = sum_fields_if_not_none(billed_units, "search_units") classifications = sum_fields_if_not_none(billed_units, "classifications") warnings = {warning for meta in metas if meta.warnings for warning in meta.warnings} return ApiMeta( api_version=api_version, billed_units=ApiMetaBilledUnits( input_tokens=input_tokens, output_tokens=output_tokens, search_units=search_units, classifications=classifications ), warnings=list(warnings) ) def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedResponse: meta = merge_meta_field([response.meta for response in responses if response.meta]) response_id = ", ".join(response.id for response in responses) texts = [ text for response in responses if response.texts is not None for text 
in response.texts ] if responses[0].response_type == "embeddings_floats": embeddings_floats = typing.cast(typing.List[EmbeddingsFloatsEmbedResponse], responses) embeddings = [ embedding for embeddings_floats in embeddings_floats for embedding in embeddings_floats.embeddings ] return EmbeddingsFloatsEmbedResponse( response_type="embeddings_floats", id=response_id, texts=texts, embeddings=embeddings, meta=meta ) else: embeddings_type = typing.cast(typing.List[EmbeddingsByTypeEmbedResponse], responses) embeddings_by_type = [ response.embeddings for response in embeddings_type ] # only get set keys from the pydantic model (i.e. exclude fields that are set to 'None') fields = [x for x in get_fields(embeddings_type[0].embeddings) if getattr(embeddings_type[0].embeddings, x) is not None] merged_dicts = { field: [ embedding for embedding_by_type in embeddings_by_type for embedding in getattr(embedding_by_type, field) ] for field in fields } embeddings_by_type_merged = EmbedByTypeResponseEmbeddings.parse_obj(merged_dicts) return EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id=response_id, embeddings=embeddings_by_type_merged, texts=texts, meta=meta ) supported_formats = ["jsonl", "csv", "avro"] def save_avro(dataset: Dataset, filepath: str): if not dataset.schema_: raise ValueError("Dataset does not have a schema") schema = parse_schema(json.loads(dataset.schema_)) with open(filepath, "wb") as outfile: writer(outfile, schema, dataset_generator(dataset)) def save_jsonl(dataset: Dataset, filepath: str): with open(filepath, "w") as outfile: for data in dataset_generator(dataset): json.dump(data, outfile) outfile.write("\n") def save_csv(dataset: Dataset, filepath: str): with open(filepath, "w") as outfile: for i, data in enumerate(dataset_generator(dataset)): if i == 0: writer = csv.DictWriter(outfile, fieldnames=list(data.keys())) writer.writeheader() writer.writerow(data) def dataset_generator(dataset: Dataset): if not dataset.dataset_parts: raise 
ValueError("Dataset does not have dataset_parts") for part in dataset.dataset_parts: if not part.url: raise ValueError("Dataset part does not have a url") resp = requests.get(part.url, stream=True) for record in reader(resp.raw): # type: ignore yield record class SdkUtils: @staticmethod def save_dataset(dataset: Dataset, filepath: str, format: typing.Literal["jsonl", "csv", "avro"] = "jsonl"): if format == "jsonl": return save_jsonl(dataset, filepath) if format == "csv": return save_csv(dataset, filepath) if format == "avro": return save_avro(dataset, filepath) raise Exception(f"unsupported format must be one of : {supported_formats}") class SyncSdkUtils(SdkUtils): pass class AsyncSdkUtils(SdkUtils): pass ================================================ FILE: src/cohere/v2/__init__.py ================================================ # This file was auto-generated by Fern from our API Definition. # isort: skip_file import typing from importlib import import_module if typing.TYPE_CHECKING: from .types import ( CitationEndV2ChatStreamResponse, CitationStartV2ChatStreamResponse, ContentDeltaV2ChatStreamResponse, ContentEndV2ChatStreamResponse, ContentStartV2ChatStreamResponse, DebugV2ChatStreamResponse, MessageEndV2ChatStreamResponse, MessageStartV2ChatStreamResponse, ToolCallDeltaV2ChatStreamResponse, ToolCallEndV2ChatStreamResponse, ToolCallStartV2ChatStreamResponse, ToolPlanDeltaV2ChatStreamResponse, V2ChatRequestDocumentsItem, V2ChatRequestSafetyMode, V2ChatRequestToolChoice, V2ChatResponse, V2ChatStreamRequestDocumentsItem, V2ChatStreamRequestSafetyMode, V2ChatStreamRequestToolChoice, V2ChatStreamResponse, V2EmbedRequestTruncate, V2RerankResponse, V2RerankResponseResultsItem, ) _dynamic_imports: typing.Dict[str, str] = { "CitationEndV2ChatStreamResponse": ".types", "CitationStartV2ChatStreamResponse": ".types", "ContentDeltaV2ChatStreamResponse": ".types", "ContentEndV2ChatStreamResponse": ".types", "ContentStartV2ChatStreamResponse": ".types", 
"DebugV2ChatStreamResponse": ".types", "MessageEndV2ChatStreamResponse": ".types", "MessageStartV2ChatStreamResponse": ".types", "ToolCallDeltaV2ChatStreamResponse": ".types", "ToolCallEndV2ChatStreamResponse": ".types", "ToolCallStartV2ChatStreamResponse": ".types", "ToolPlanDeltaV2ChatStreamResponse": ".types", "V2ChatRequestDocumentsItem": ".types", "V2ChatRequestSafetyMode": ".types", "V2ChatRequestToolChoice": ".types", "V2ChatResponse": ".types", "V2ChatStreamRequestDocumentsItem": ".types", "V2ChatStreamRequestSafetyMode": ".types", "V2ChatStreamRequestToolChoice": ".types", "V2ChatStreamResponse": ".types", "V2EmbedRequestTruncate": ".types", "V2RerankResponse": ".types", "V2RerankResponseResultsItem": ".types", } def __getattr__(attr_name: str) -> typing.Any: module_name = _dynamic_imports.get(attr_name) if module_name is None: raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}") try: module = import_module(module_name, __package__) if module_name == f".{attr_name}": return module else: return getattr(module, attr_name) except ImportError as e: raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e except AttributeError as e: raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e def __dir__(): lazy_attrs = list(_dynamic_imports.keys()) return sorted(lazy_attrs) __all__ = [ "CitationEndV2ChatStreamResponse", "CitationStartV2ChatStreamResponse", "ContentDeltaV2ChatStreamResponse", "ContentEndV2ChatStreamResponse", "ContentStartV2ChatStreamResponse", "DebugV2ChatStreamResponse", "MessageEndV2ChatStreamResponse", "MessageStartV2ChatStreamResponse", "ToolCallDeltaV2ChatStreamResponse", "ToolCallEndV2ChatStreamResponse", "ToolCallStartV2ChatStreamResponse", "ToolPlanDeltaV2ChatStreamResponse", "V2ChatRequestDocumentsItem", "V2ChatRequestSafetyMode", "V2ChatRequestToolChoice", "V2ChatResponse", "V2ChatStreamRequestDocumentsItem", "V2ChatStreamRequestSafetyMode", 
"V2ChatStreamRequestToolChoice", "V2ChatStreamResponse", "V2EmbedRequestTruncate", "V2RerankResponse", "V2RerankResponseResultsItem", ] ================================================ FILE: src/cohere/v2/client.py ================================================ # This file was auto-generated by Fern from our API Definition. import typing from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.request_options import RequestOptions from ..types.chat_messages import ChatMessages from ..types.citation_options import CitationOptions from ..types.embed_by_type_response import EmbedByTypeResponse from ..types.embed_input import EmbedInput from ..types.embed_input_type import EmbedInputType from ..types.embedding_type import EmbeddingType from ..types.response_format_v2 import ResponseFormatV2 from ..types.thinking import Thinking from ..types.tool_v2 import ToolV2 from .raw_client import AsyncRawV2Client, RawV2Client from .types.v2chat_request_documents_item import V2ChatRequestDocumentsItem from .types.v2chat_request_safety_mode import V2ChatRequestSafetyMode from .types.v2chat_request_tool_choice import V2ChatRequestToolChoice from .types.v2chat_response import V2ChatResponse from .types.v2chat_stream_request_documents_item import V2ChatStreamRequestDocumentsItem from .types.v2chat_stream_request_safety_mode import V2ChatStreamRequestSafetyMode from .types.v2chat_stream_request_tool_choice import V2ChatStreamRequestToolChoice from .types.v2chat_stream_response import V2ChatStreamResponse from .types.v2embed_request_truncate import V2EmbedRequestTruncate from .types.v2rerank_response import V2RerankResponse # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) 
class V2Client:
    """Synchronous client for the Cohere v2 endpoints (chat, chat_stream, embed, rerank)."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        # All HTTP concerns are delegated to the generated raw client.
        self._raw_client = RawV2Client(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawV2Client:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawV2Client
        """
        return self._raw_client

    def chat_stream(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.Iterator[V2ChatStreamResponse]:
        """
        Generates a text response to a user message, streamed as a sequence of events.
        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

        Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

        strict_tools : typing.Optional[bool]
            When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

            **Note**: The first few requests with a new set of tools will take longer to process.

        documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `OFF` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        max_tokens : typing.Optional[int]
            The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

            **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

            **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        k : typing.Optional[int]
            Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`, min value of `0.01`, max value of `0.99`.

        logprobs : typing.Optional[bool]
            Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

        tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
            Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
            When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
            If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

            **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

        thinking : typing.Optional[Thinking]

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.Iterator[V2ChatStreamResponse]


        Examples
        --------
        from cohere import Client, UserChatMessageV2

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        response = client.v2.chat_stream(
            model="command-a-03-2025",
            messages=[
                UserChatMessageV2(
                    content="Tell me about LLMs",
                )
            ],
        )
        for chunk in response:
            print(chunk)
        """
        # The raw client returns a context-managed streamed response; OMIT-valued
        # arguments are forwarded as-is (presumably dropped from the wire request
        # by the raw client — verify in raw_client.py).
        with self._raw_client.chat_stream(
            model=model,
            messages=messages,
            tools=tools,
            strict_tools=strict_tools,
            documents=documents,
            citation_options=citation_options,
            response_format=response_format,
            safety_mode=safety_mode,
            max_tokens=max_tokens,
            stop_sequences=stop_sequences,
            temperature=temperature,
            seed=seed,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            k=k,
            p=p,
            logprobs=logprobs,
            tool_choice=tool_choice,
            thinking=thinking,
            priority=priority,
            request_options=request_options,
        ) as r:
            yield from r.data

    def chat(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> V2ChatResponse:
        """
        Generates a text response to a user message, returned in a single response object.
        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

        Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

        strict_tools : typing.Optional[bool]
            When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

            **Note**: The first few requests with a new set of tools will take longer to process.

        documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `OFF` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        max_tokens : typing.Optional[int]
            The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

            **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

            **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        k : typing.Optional[int]
            Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`, min value of `0.01`, max value of `0.99`.

        logprobs : typing.Optional[bool]
            Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

        tool_choice : typing.Optional[V2ChatRequestToolChoice]
            Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
            When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
            If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

            **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

        thinking : typing.Optional[Thinking]

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        V2ChatResponse


        Examples
        --------
        from cohere import Client, UserChatMessageV2

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.v2.chat(
            model="command-a-03-2025",
            messages=[
                UserChatMessageV2(
                    content="Tell me about LLMs",
                )
            ],
        )
        """
        # Delegate to the raw client and unwrap the parsed response body.
        _response = self._raw_client.chat(
            model=model,
            messages=messages,
            tools=tools,
            strict_tools=strict_tools,
            documents=documents,
            citation_options=citation_options,
            response_format=response_format,
            safety_mode=safety_mode,
            max_tokens=max_tokens,
            stop_sequences=stop_sequences,
            temperature=temperature,
            seed=seed,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            k=k,
            p=p,
            logprobs=logprobs,
            tool_choice=tool_choice,
            thinking=thinking,
            priority=priority,
            request_options=request_options,
        )
        return _response.data

    def embed(
        self,
        *,
        model: str,
        input_type: EmbedInputType,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        output_dimension: typing.Optional[int] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> EmbedByTypeResponse:
        """
        This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.

        Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

        If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

        Parameters
        ----------
        model : str
            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
        input_type : EmbedInputType

        texts : typing.Optional[typing.Sequence[str]]
            An array of strings for the model to embed. Maximum number of texts per call is `96`.

        images : typing.Optional[typing.Sequence[str]]
            An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

            The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.

            Image embeddings are supported with Embed v3.0 and newer models.

        inputs : typing.Optional[typing.Sequence[EmbedInput]]
            An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components.

        max_tokens : typing.Optional[int]
            The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.

        output_dimension : typing.Optional[int]
            The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models.
            Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models.

        truncate : typing.Optional[V2EmbedRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedByTypeResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.v2.embed(
            texts=["hello", "goodbye"],
            model="embed-v4.0",
            input_type="classification",
            embedding_types=["float"],
        )
        """
        # Delegate to the raw client and unwrap the parsed response body.
        _response = self._raw_client.embed(
            model=model,
            input_type=input_type,
            texts=texts,
            images=images,
            inputs=inputs,
            max_tokens=max_tokens,
            output_dimension=output_dimension,
            embedding_types=embedding_types,
            truncate=truncate,
            priority=priority,
            request_options=request_options,
        )
        return _response.data

    def rerank(
        self,
        *,
        model: str,
        query: str,
        documents: typing.Sequence[str],
        top_n: typing.Optional[int] = OMIT,
        max_tokens_per_doc: typing.Optional[int] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> V2RerankResponse:
        """
        This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
        Parameters
        ----------
        model : str
            The identifier of the model to use, eg `rerank-v3.5`.

        query : str
            The search query

        documents : typing.Sequence[str]
            A list of texts that will be compared to the `query`.
            For optimal performance we recommend against sending more than 1,000 documents in a single request.

            **Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`.

            **Note**: structured data should be formatted as YAML strings for best performance.

        top_n : typing.Optional[int]
            Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.

        max_tokens_per_doc : typing.Optional[int]
            Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        V2RerankResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.v2.rerank(
            documents=[
                "Carson City is the capital city of the American state of Nevada.",
                "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
                "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
                "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
                "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
            ],
            query="What is the capital of the United States?",
            top_n=3,
            model="rerank-v4.0-pro",
        )
        """
        # Delegate to the raw client and unwrap the parsed response body.
        _response = self._raw_client.rerank(
            model=model,
            query=query,
            documents=documents,
            top_n=top_n,
            max_tokens_per_doc=max_tokens_per_doc,
            priority=priority,
            request_options=request_options,
        )
        return _response.data


class AsyncV2Client:
    """Asynchronous counterpart of ``V2Client``; same endpoints via the async raw client."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        # All HTTP concerns are delegated to the generated async raw client.
        self._raw_client = AsyncRawV2Client(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawV2Client:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawV2Client
        """
        return self._raw_client

    async def chat_stream(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.AsyncIterator[V2ChatStreamResponse]:
        """
        Generates a text response to a user message, streamed as a sequence of events.
        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
        Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

        strict_tools : typing.Optional[bool]
            When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

            **Note**: The first few requests with a new set of tools will take longer to process.

        documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `OFF` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        max_tokens : typing.Optional[int]
            The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

            **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

            **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        k : typing.Optional[int]
            Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`, min value of `0.01`, max value of `0.99`.

        logprobs : typing.Optional[bool]
            Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

        tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
            Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
            When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
            If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

            **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

        thinking : typing.Optional[Thinking]

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.AsyncIterator[V2ChatStreamResponse]


        Examples
        --------
        import asyncio

        from cohere import AsyncClient, UserChatMessageV2

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            # chat_stream is an async generator function: call it (without
            # awaiting) and iterate the returned async generator.
            response = client.v2.chat_stream(
                model="command-a-03-2025",
                messages=[
                    UserChatMessageV2(
                        content="Tell me about LLMs",
                    )
                ],
            )
            async for chunk in response:
                print(chunk)


        asyncio.run(main())
        """
        # This method contains `yield`, making it an async generator; callers
        # iterate the result of calling it rather than awaiting the call.
        async with self._raw_client.chat_stream(
            model=model,
            messages=messages,
            tools=tools,
            strict_tools=strict_tools,
            documents=documents,
            citation_options=citation_options,
            response_format=response_format,
            safety_mode=safety_mode,
            max_tokens=max_tokens,
            stop_sequences=stop_sequences,
            temperature=temperature,
            seed=seed,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            k=k,
            p=p,
            logprobs=logprobs,
            tool_choice=tool_choice,
            thinking=thinking,
            priority=priority,
            request_options=request_options,
        ) as r:
            async for _chunk in r.data:
                yield _chunk

    async def chat(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options:
typing.Optional[RequestOptions] = None, ) -> V2ChatResponse: """ Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api). Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2. Parameters ---------- model : str The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models). messages : ChatMessages tools : typing.Optional[typing.Sequence[ToolV2]] A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools). strict_tools : typing.Optional[bool] When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools). **Note**: The first few requests with a new set of tools will take longer to process. documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata. citation_options : typing.Optional[CitationOptions] response_format : typing.Optional[ResponseFormatV2] safety_mode : typing.Optional[V2ChatRequestSafetyMode] Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `OFF` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools` and `documents` parameters. 
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. max_tokens : typing.Optional[int] The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models). **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`. **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit. stop_sequences : typing.Optional[typing.Sequence[str]] A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. temperature : typing.Optional[float] Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter. seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. frequency_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. 
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. k : typing.Optional[int] Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled. Defaults to `0`, min value of `0`, max value of `500`. p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. logprobs : typing.Optional[bool] Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response. tool_choice : typing.Optional[V2ChatRequestToolChoice] Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request. When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response. If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not. **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer. thinking : typing.Optional[Thinking] priority : typing.Optional[int] Controls how early the request is handled. 
Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- V2ChatResponse Examples -------- import asyncio from cohere import AsyncClient, UserChatMessageV2 client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.v2.chat( model="command-a-03-2025", messages=[ UserChatMessageV2( content="Tell me about LLMs", ) ], ) asyncio.run(main()) """ _response = await self._raw_client.chat( model=model, messages=messages, tools=tools, strict_tools=strict_tools, documents=documents, citation_options=citation_options, response_format=response_format, safety_mode=safety_mode, max_tokens=max_tokens, stop_sequences=stop_sequences, temperature=temperature, seed=seed, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, k=k, p=p, logprobs=logprobs, tool_choice=tool_choice, thinking=thinking, priority=priority, request_options=request_options, ) return _response.data async def embed( self, *, model: str, input_type: EmbedInputType, texts: typing.Optional[typing.Sequence[str]] = OMIT, images: typing.Optional[typing.Sequence[str]] = OMIT, inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT, max_tokens: typing.Optional[int] = OMIT, output_dimension: typing.Optional[int] = OMIT, embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT, truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT, priority: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> EmbedByTypeResponse: """ This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents. Embeddings can be used to create text classifiers as well as empower semantic search. 
To learn more about embeddings, see the embedding page. If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search). Parameters ---------- model : str ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed). input_type : EmbedInputType texts : typing.Optional[typing.Sequence[str]] An array of strings for the model to embed. Maximum number of texts per call is `96`. images : typing.Optional[typing.Sequence[str]] An array of image data URIs for the model to embed. Maximum number of images per call is `1`. The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB. Image embeddings are supported with Embed v3.0 and newer models. inputs : typing.Optional[typing.Sequence[EmbedInput]] An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components. max_tokens : typing.Optional[int] The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter. output_dimension : typing.Optional[int] The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models. Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`. embedding_types : typing.Optional[typing.Sequence[EmbeddingType]] Specifies the types of embeddings you want to get back. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. 
Supported with Embed v3.0 and newer Embed models. * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models. * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models. * `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models. truncate : typing.Optional[V2EmbedRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. priority : typing.Optional[int] Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Returns ------- EmbedByTypeResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.v2.embed( texts=["hello", "goodbye"], model="embed-v4.0", input_type="classification", embedding_types=["float"], ) asyncio.run(main()) """ _response = await self._raw_client.embed( model=model, input_type=input_type, texts=texts, images=images, inputs=inputs, max_tokens=max_tokens, output_dimension=output_dimension, embedding_types=embedding_types, truncate=truncate, priority=priority, request_options=request_options, ) return _response.data async def rerank( self, *, model: str, query: str, documents: typing.Sequence[str], top_n: typing.Optional[int] = OMIT, max_tokens_per_doc: typing.Optional[int] = OMIT, priority: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> V2RerankResponse: """ This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score. Parameters ---------- model : str The identifier of the model to use, eg `rerank-v3.5`. query : str The search query documents : typing.Sequence[str] A list of texts that will be compared to the `query`. For optimal performance we recommend against sending more than 1,000 documents in a single request. **Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`. **Note**: structured data should be formatted as YAML strings for best performance. top_n : typing.Optional[int] Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned. max_tokens_per_doc : typing.Optional[int] Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens. priority : typing.Optional[int] Controls how early the request is handled. 
Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- V2RerankResponse OK Examples -------- import asyncio from cohere import AsyncClient client = AsyncClient( client_name="YOUR_CLIENT_NAME", token="YOUR_TOKEN", ) async def main() -> None: await client.v2.rerank( documents=[ "Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", "Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", ], query="What is the capital of the United States?", top_n=3, model="rerank-v4.0-pro", ) asyncio.run(main()) """ _response = await self._raw_client.rerank( model=model, query=query, documents=documents, top_n=top_n, max_tokens_per_doc=max_tokens_per_doc, priority=priority, request_options=request_options, ) return _response.data ================================================ FILE: src/cohere/v2/raw_client.py ================================================ # This file was auto-generated by Fern from our API Definition. 
import contextlib import typing from json.decoder import JSONDecodeError from logging import error, warning from ..core.api_error import ApiError from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper from ..core.http_response import AsyncHttpResponse, HttpResponse from ..core.http_sse._api import EventSource from ..core.parse_error import ParsingError from ..core.pydantic_utilities import parse_sse_obj from ..core.request_options import RequestOptions from ..core.serialization import convert_and_respect_annotation_metadata from ..core.unchecked_base_model import construct_type from ..errors.bad_request_error import BadRequestError from ..errors.client_closed_request_error import ClientClosedRequestError from ..errors.forbidden_error import ForbiddenError from ..errors.gateway_timeout_error import GatewayTimeoutError from ..errors.internal_server_error import InternalServerError from ..errors.invalid_token_error import InvalidTokenError from ..errors.not_found_error import NotFoundError from ..errors.not_implemented_error import NotImplementedError from ..errors.service_unavailable_error import ServiceUnavailableError from ..errors.too_many_requests_error import TooManyRequestsError from ..errors.unauthorized_error import UnauthorizedError from ..errors.unprocessable_entity_error import UnprocessableEntityError from ..types.chat_messages import ChatMessages from ..types.citation_options import CitationOptions from ..types.embed_by_type_response import EmbedByTypeResponse from ..types.embed_input import EmbedInput from ..types.embed_input_type import EmbedInputType from ..types.embedding_type import EmbeddingType from ..types.response_format_v2 import ResponseFormatV2 from ..types.thinking import Thinking from ..types.tool_v2 import ToolV2 from .types.v2chat_request_documents_item import V2ChatRequestDocumentsItem from .types.v2chat_request_safety_mode import V2ChatRequestSafetyMode from .types.v2chat_request_tool_choice import 
V2ChatRequestToolChoice
from .types.v2chat_response import V2ChatResponse
from .types.v2chat_stream_request_documents_item import V2ChatStreamRequestDocumentsItem
from .types.v2chat_stream_request_safety_mode import V2ChatStreamRequestSafetyMode
from .types.v2chat_stream_request_tool_choice import V2ChatStreamRequestToolChoice
from .types.v2chat_stream_response import V2ChatStreamResponse
from .types.v2embed_request_truncate import V2EmbedRequestTruncate
from .types.v2rerank_response import V2RerankResponse
from pydantic import ValidationError

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)


class RawV2Client:
    """Low-level synchronous client for the v2 API: performs the HTTP calls and
    returns `HttpResponse`-wrapped results for the higher-level client to unwrap."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        # Shared wrapper that owns the configured httpx client, base URL and auth headers.
        self._client_wrapper = client_wrapper

    @contextlib.contextmanager
    def chat_stream(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.Iterator[HttpResponse[typing.Iterator[V2ChatStreamResponse]]]:
        """
        Generates a text response to a user message, streamed back token by token as
        server-sent events.

        To learn how to use the Chat API and RAG follow our
        [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

        Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2)
        for instructions on moving from API v1 to API v2.

        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of tools (functions) available to the model; the response may
            contain 'tool_calls' to these tools.

        strict_tools : typing.Optional[bool]
            When `true`, tool calls in the Assistant message are forced to follow the
            tool definition strictly.

        documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
            Relevant documents the model can cite; each is a string or a document
            object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
            Selects the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes)
            inserted into the prompt. Defaults to `CONTEXTUAL`; `OFF` omits it.

        max_tokens : typing.Optional[int]
            Maximum number of output tokens; defaults to the model's output limit.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            Up to 5 strings that stop generation; the stop sequence is not included.

        temperature : typing.Optional[float]
            Defaults to `0.3`; non-negative, higher values mean more random output.

        seed : typing.Optional[int]
            Best-effort deterministic sampling; determinism is not guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min `0.0`, max `1.0`; penalizes tokens proportionally
            to prior occurrences.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min `0.0`, max `1.0`; penalizes all previously seen
            tokens equally.

        k : typing.Optional[int]
            Top-k sampling; `0` disables. Defaults to `0`, min `0`, max `500`.

        p : typing.Optional[float]
            Nucleus sampling; acts after `k` when both set. Defaults to `0.75`,
            min `0.01`, max `0.99`.

        logprobs : typing.Optional[bool]
            Defaults to `false`; `true` includes log probabilities in the response.

        tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
            `REQUIRED` forces tool use (requires `tools`); `NONE` forbids it; unset
            leaves the choice to the model.

        thinking : typing.Optional[Thinking]

        priority : typing.Optional[int]
            Lower numbers indicate higher priority (default: 0, the highest).

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.Iterator[HttpResponse[typing.Iterator[V2ChatStreamResponse]]]
        """
        # Open the HTTP stream; "stream": True asks the API for SSE output.
        with self._client_wrapper.httpx_client.stream(
            "v2/chat",
            method="POST",
            json={
                "model": model,
                "messages": convert_and_respect_annotation_metadata(
                    object_=messages, annotation=ChatMessages, direction="write"
                ),
                "tools": convert_and_respect_annotation_metadata(
                    object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
                ),
                "strict_tools": strict_tools,
                "documents": convert_and_respect_annotation_metadata(
                    object_=documents, annotation=typing.Sequence[V2ChatStreamRequestDocumentsItem], direction="write"
                ),
                "citation_options": convert_and_respect_annotation_metadata(
                    object_=citation_options, annotation=CitationOptions, direction="write"
                ),
                "response_format": convert_and_respect_annotation_metadata(
                    object_=response_format, annotation=ResponseFormatV2, direction="write"
                ),
                "safety_mode": safety_mode,
                "max_tokens": max_tokens,
                "stop_sequences": stop_sequences,
                "temperature": temperature,
                "seed": seed,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "k": k,
                "p": p,
                "logprobs": logprobs,
                "tool_choice": tool_choice,
                "thinking": convert_and_respect_annotation_metadata(
                    object_=thinking, annotation=Thinking, direction="write"
                ),
                "priority": priority,
                "stream": True,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        ) as _response:

            def _stream() -> HttpResponse[typing.Iterator[V2ChatStreamResponse]]:
                try:
                    if 200 <= _response.status_code < 300:
                        # Success path: lazily parse each SSE event into a typed
                        # V2ChatStreamResponse; "[DONE]" terminates the stream.
                        def _iter():
                            _event_source = EventSource(_response)
                            for _sse in _event_source.iter_sse():
                                if _sse.data == "[DONE]":
                                    return
                                try:
                                    yield typing.cast(
                                        V2ChatStreamResponse,
                                        parse_sse_obj(
                                            sse=_sse,
                                            type_=V2ChatStreamResponse,  # type: ignore
                                        ),
                                    )
                                except JSONDecodeError as e:
                                    # Malformed event payloads are skipped, not fatal.
                                    warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}")
                                except (TypeError, ValueError, KeyError, AttributeError) as e:
                                    # Events that fail model construction are skipped too.
                                    warning(
                                        f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}"
                                    )
                                except Exception as e:
                                    # Anything else is logged but does not abort the stream.
                                    error(
                                        f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}"
                                    )
                            return

                        return HttpResponse(response=_response, data=_iter())
                    # Error path: the body must be read before .json() on a stream.
                    _response.read()
                    if _response.status_code == 400:
                        raise BadRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 401:
                        raise UnauthorizedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 403:
                        raise ForbiddenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 404:
                        raise NotFoundError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 422:
                        raise UnprocessableEntityError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 429:
                        raise TooManyRequestsError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 498:
                        raise InvalidTokenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 499:
                        raise ClientClosedRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 500:
                        raise InternalServerError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 501:
                        raise NotImplementedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 503:
                        raise ServiceUnavailableError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 504:
                        raise GatewayTimeoutError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    # Unmapped status: fall through to a generic ApiError with the JSON body.
                    _response_json = _response.json()
                except JSONDecodeError:
                    # Non-JSON error body: surface the raw text instead.
                    raise ApiError(
                        status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
                    )
                except ValidationError as e:
                    raise ParsingError(
                        status_code=_response.status_code,
                        headers=dict(_response.headers),
                        body=_response.json(),
                        cause=e,
                    )
                raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

            yield _stream()

    def chat(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatRequestToolChoice] =
OMIT, thinking: typing.Optional[Thinking] = OMIT, priority: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[V2ChatResponse]: """ Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api). Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2. Parameters ---------- model : str The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models). messages : ChatMessages tools : typing.Optional[typing.Sequence[ToolV2]] A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools). strict_tools : typing.Optional[bool] When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools). **Note**: The first few requests with a new set of tools will take longer to process. documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata. citation_options : typing.Optional[CitationOptions] response_format : typing.Optional[ResponseFormatV2] safety_mode : typing.Optional[V2ChatRequestSafetyMode] Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `OFF` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools` and `documents` parameters. 
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. max_tokens : typing.Optional[int] The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models). **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`. **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit. stop_sequences : typing.Optional[typing.Sequence[str]] A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. temperature : typing.Optional[float] Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter. seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. frequency_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. 
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. k : typing.Optional[int] Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled. Defaults to `0`, min value of `0`, max value of `500`. p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. min value of `0.01`, max value of `0.99`. logprobs : typing.Optional[bool] Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response. tool_choice : typing.Optional[V2ChatRequestToolChoice] Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request. When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response. If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not. **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer. thinking : typing.Optional[Thinking] priority : typing.Optional[int] Controls how early the request is handled. 
Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[V2ChatResponse] """ _response = self._client_wrapper.httpx_client.request( "v2/chat", method="POST", json={ "model": model, "messages": convert_and_respect_annotation_metadata( object_=messages, annotation=ChatMessages, direction="write" ), "tools": convert_and_respect_annotation_metadata( object_=tools, annotation=typing.Sequence[ToolV2], direction="write" ), "strict_tools": strict_tools, "documents": convert_and_respect_annotation_metadata( object_=documents, annotation=typing.Sequence[V2ChatRequestDocumentsItem], direction="write" ), "citation_options": convert_and_respect_annotation_metadata( object_=citation_options, annotation=CitationOptions, direction="write" ), "response_format": convert_and_respect_annotation_metadata( object_=response_format, annotation=ResponseFormatV2, direction="write" ), "safety_mode": safety_mode, "max_tokens": max_tokens, "stop_sequences": stop_sequences, "temperature": temperature, "seed": seed, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "k": k, "p": p, "logprobs": logprobs, "tool_choice": tool_choice, "thinking": convert_and_respect_annotation_metadata( object_=thinking, annotation=Thinking, direction="write" ), "priority": priority, "stream": False, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( V2ChatResponse, construct_type( type_=V2ChatResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( 
type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def embed( self, *, model: str, input_type: EmbedInputType, texts: typing.Optional[typing.Sequence[str]] = OMIT, images: typing.Optional[typing.Sequence[str]] = OMIT, inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT, max_tokens: typing.Optional[int] = OMIT, output_dimension: typing.Optional[int] = OMIT, embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT, truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT, priority: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[EmbedByTypeResponse]: """ This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents. Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page. If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search). Parameters ---------- model : str ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed). 
input_type : EmbedInputType texts : typing.Optional[typing.Sequence[str]] An array of strings for the model to embed. Maximum number of texts per call is `96`. images : typing.Optional[typing.Sequence[str]] An array of image data URIs for the model to embed. Maximum number of images per call is `1`. The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB. Image embeddings are supported with Embed v3.0 and newer models. inputs : typing.Optional[typing.Sequence[EmbedInput]] An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components. max_tokens : typing.Optional[int] The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter. output_dimension : typing.Optional[int] The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models. Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`. embedding_types : typing.Optional[typing.Sequence[EmbeddingType]] Specifies the types of embeddings you want to get back. Can be one or more of the following types. * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models. * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models. * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models. * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models. 
* `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models. truncate : typing.Optional[V2EmbedRequestTruncate] One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. priority : typing.Optional[int] Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[EmbedByTypeResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v2/embed", method="POST", json={ "texts": texts, "images": images, "model": model, "input_type": input_type, "inputs": convert_and_respect_annotation_metadata( object_=inputs, annotation=typing.Sequence[EmbedInput], direction="write" ), "max_tokens": max_tokens, "output_dimension": output_dimension, "embedding_types": embedding_types, "truncate": truncate, "priority": priority, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( EmbedByTypeResponse, construct_type( type_=EmbedByTypeResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise 
UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore 
object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) def rerank( self, *, model: str, query: str, documents: typing.Sequence[str], top_n: typing.Optional[int] = OMIT, max_tokens_per_doc: typing.Optional[int] = OMIT, priority: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> HttpResponse[V2RerankResponse]: """ This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score. Parameters ---------- model : str The identifier of the model to use, eg `rerank-v3.5`. query : str The search query documents : typing.Sequence[str] A list of texts that will be compared to the `query`. For optimal performance we recommend against sending more than 1,000 documents in a single request. **Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`. **Note**: structured data should be formatted as YAML strings for best performance. top_n : typing.Optional[int] Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned. max_tokens_per_doc : typing.Optional[int] Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens. priority : typing.Optional[int] Controls how early the request is handled. 
Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped. request_options : typing.Optional[RequestOptions] Request-specific configuration. Returns ------- HttpResponse[V2RerankResponse] OK """ _response = self._client_wrapper.httpx_client.request( "v2/rerank", method="POST", json={ "model": model, "query": query, "documents": documents, "top_n": top_n, "max_tokens_per_doc": max_tokens_per_doc, "priority": priority, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) try: if 200 <= _response.status_code < 300: _data = typing.cast( V2RerankResponse, construct_type( type_=V2RerankResponse, # type: ignore object_=_response.json(), ), ) return HttpResponse(response=_response, data=_data) if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( 
typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawV2Client: def __init__(self, *, client_wrapper: AsyncClientWrapper): self._client_wrapper = client_wrapper @contextlib.asynccontextmanager async def chat_stream( self, *, model: str, messages: ChatMessages, tools: 
typing.Optional[typing.Sequence[ToolV2]] = OMIT, strict_tools: typing.Optional[bool] = OMIT, documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT, citation_options: typing.Optional[CitationOptions] = OMIT, response_format: typing.Optional[ResponseFormatV2] = OMIT, safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT, max_tokens: typing.Optional[int] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, temperature: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, presence_penalty: typing.Optional[float] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, logprobs: typing.Optional[bool] = OMIT, tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT, thinking: typing.Optional[Thinking] = OMIT, priority: typing.Optional[int] = OMIT, request_options: typing.Optional[RequestOptions] = None, ) -> typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamResponse]]]: """ Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api). Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2. Parameters ---------- model : str The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models). messages : ChatMessages tools : typing.Optional[typing.Sequence[ToolV2]] A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools). strict_tools : typing.Optional[bool] When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. 
Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools). **Note**: The first few requests with a new set of tools will take longer to process. documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata. citation_options : typing.Optional[CitationOptions] response_format : typing.Optional[ResponseFormatV2] safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode] Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. When `OFF` is specified, the safety instruction will be omitted. Safety modes are not yet configurable in combination with `tools` and `documents` parameters. **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release). **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes. max_tokens : typing.Optional[int] The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models). **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`. **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit. 
stop_sequences : typing.Optional[typing.Sequence[str]] A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. temperature : typing.Optional[float] Defaults to `0.3`. A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. Randomness can be further maximized by increasing the value of the `p` parameter. seed : typing.Optional[int] If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. frequency_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. presence_penalty : typing.Optional[float] Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. k : typing.Optional[int] Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled. Defaults to `0`, min value of `0`, max value of `500`. p : typing.Optional[float] Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`. 
min value of `0.01`, max value of `0.99`. logprobs : typing.Optional[bool] Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response. tool_choice : typing.Optional[V2ChatStreamRequestToolChoice] Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request. When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response. If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not. **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer. thinking : typing.Optional[Thinking] priority : typing.Optional[int] Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped. request_options : typing.Optional[RequestOptions] Request-specific configuration. 
Yields ------ typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamResponse]]] """ async with self._client_wrapper.httpx_client.stream( "v2/chat", method="POST", json={ "model": model, "messages": convert_and_respect_annotation_metadata( object_=messages, annotation=ChatMessages, direction="write" ), "tools": convert_and_respect_annotation_metadata( object_=tools, annotation=typing.Sequence[ToolV2], direction="write" ), "strict_tools": strict_tools, "documents": convert_and_respect_annotation_metadata( object_=documents, annotation=typing.Sequence[V2ChatStreamRequestDocumentsItem], direction="write" ), "citation_options": convert_and_respect_annotation_metadata( object_=citation_options, annotation=CitationOptions, direction="write" ), "response_format": convert_and_respect_annotation_metadata( object_=response_format, annotation=ResponseFormatV2, direction="write" ), "safety_mode": safety_mode, "max_tokens": max_tokens, "stop_sequences": stop_sequences, "temperature": temperature, "seed": seed, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "k": k, "p": p, "logprobs": logprobs, "tool_choice": tool_choice, "thinking": convert_and_respect_annotation_metadata( object_=thinking, annotation=Thinking, direction="write" ), "priority": priority, "stream": True, }, headers={ "content-type": "application/json", }, request_options=request_options, omit=OMIT, ) as _response: async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamResponse]]: try: if 200 <= _response.status_code < 300: async def _iter(): _event_source = EventSource(_response) async for _sse in _event_source.aiter_sse(): if _sse.data == "[DONE]": return try: yield typing.cast( V2ChatStreamResponse, parse_sse_obj( sse=_sse, type_=V2ChatStreamResponse, # type: ignore ), ) except JSONDecodeError as e: warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") except (TypeError, ValueError, KeyError, AttributeError) as e: warning( 
f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" ) except Exception as e: error( f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" ) return return AsyncHttpResponse(response=_response, data=_iter()) await _response.aread() if _response.status_code == 400: raise BadRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 401: raise UnauthorizedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 403: raise ForbiddenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 404: raise NotFoundError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 422: raise UnprocessableEntityError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 429: raise TooManyRequestsError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 498: raise InvalidTokenError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 499: raise ClientClosedRequestError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 500: raise InternalServerError( 
headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 501: raise NotImplementedError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 503: raise ServiceUnavailableError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) if _response.status_code == 504: raise GatewayTimeoutError( headers=dict(_response.headers), body=typing.cast( typing.Any, construct_type( type_=typing.Any, # type: ignore object_=_response.json(), ), ), ) _response_json = _response.json() except JSONDecodeError: raise ApiError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.text ) except ValidationError as e: raise ParsingError( status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e, ) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) yield await _stream() async def chat( self, *, model: str, messages: ChatMessages, tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT, strict_tools: typing.Optional[bool] = OMIT, documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT, citation_options: typing.Optional[CitationOptions] = OMIT, response_format: typing.Optional[ResponseFormatV2] = OMIT, safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT, max_tokens: typing.Optional[int] = OMIT, stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT, temperature: typing.Optional[float] = OMIT, seed: typing.Optional[int] = OMIT, frequency_penalty: typing.Optional[float] = OMIT, presence_penalty: typing.Optional[float] = OMIT, k: typing.Optional[int] = OMIT, p: typing.Optional[float] = OMIT, logprobs: 
        typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[V2ChatResponse]:
        """
        Generates a text response to a user message. To learn how to use the Chat API follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

        Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools. Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

        strict_tools : typing.Optional[bool]
            When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

            **Note**: The first few requests with a new set of tools will take longer to process.

        documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `OFF` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        max_tokens : typing.Optional[int]
            The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

            **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

            **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        k : typing.Optional[int]
            Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        logprobs : typing.Optional[bool]
            Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

        tool_choice : typing.Optional[V2ChatRequestToolChoice]
            Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
            When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
            If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

            **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

        thinking : typing.Optional[Thinking]

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[V2ChatResponse]
        """
        # Non-streaming variant: "stream" is pinned to False below; the streaming
        # sibling method issues the same request with stream=True and SSE decoding.
        _response = await self._client_wrapper.httpx_client.request(
            "v2/chat",
            method="POST",
            json={
                "model": model,
                "messages": convert_and_respect_annotation_metadata(
                    object_=messages, annotation=ChatMessages, direction="write"
                ),
                "tools": convert_and_respect_annotation_metadata(
                    object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
                ),
                "strict_tools": strict_tools,
                "documents": convert_and_respect_annotation_metadata(
                    object_=documents, annotation=typing.Sequence[V2ChatRequestDocumentsItem], direction="write"
                ),
                "citation_options": convert_and_respect_annotation_metadata(
                    object_=citation_options, annotation=CitationOptions, direction="write"
                ),
                "response_format": convert_and_respect_annotation_metadata(
                    object_=response_format, annotation=ResponseFormatV2, direction="write"
                ),
                "safety_mode": safety_mode,
                "max_tokens": max_tokens,
                "stop_sequences": stop_sequences,
                "temperature": temperature,
                "seed": seed,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "k": k,
                "p": p,
                "logprobs": logprobs,
                "tool_choice": tool_choice,
                "thinking": convert_and_respect_annotation_metadata(
                    object_=thinking, annotation=Thinking, direction="write"
                ),
                "priority": priority,
                "stream": False,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        # Map non-2xx statuses onto typed SDK exceptions; anything unrecognized
        # (or a non-JSON body) falls through to a generic ApiError at the end.
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    V2ChatResponse,
                    construct_type(
                        type_=V2ChatResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                # NOTE(review): NotImplementedError here is the SDK's error class
                # imported from the errors module, not the Python builtin — confirm
                # against this file's import block.
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    async def embed(
        self,
        *,
        model: str,
        input_type: EmbedInputType,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        output_dimension: typing.Optional[int] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[EmbedByTypeResponse]:
        """
        This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.

        Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

        If you want to learn more about how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

        Parameters
        ----------
        model : str
            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

        input_type : EmbedInputType

        texts : typing.Optional[typing.Sequence[str]]
            An array of strings for the model to embed. Maximum number of texts per call is `96`.

        images : typing.Optional[typing.Sequence[str]]
            An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

            The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.

            Image embeddings are supported with Embed v3.0 and newer models.

        inputs : typing.Optional[typing.Sequence[EmbedInput]]
            An array of inputs for the model to embed. Maximum number of inputs per call is `96`.
            An input can contain a mix of text and image components.

        max_tokens : typing.Optional[int]
            The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.

        output_dimension : typing.Optional[int]
            The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models.
            Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models.

        truncate : typing.Optional[V2EmbedRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[EmbedByTypeResponse]
            OK
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v2/embed",
            method="POST",
            json={
                "texts": texts,
                "images": images,
                "model": model,
                "input_type": input_type,
                "inputs": convert_and_respect_annotation_metadata(
                    object_=inputs, annotation=typing.Sequence[EmbedInput], direction="write"
                ),
                "max_tokens": max_tokens,
                "output_dimension": output_dimension,
                "embedding_types": embedding_types,
                "truncate": truncate,
                "priority": priority,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        # Same status-to-exception mapping as chat(); unknown statuses and
        # non-JSON bodies fall through to the generic ApiError at the end.
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    EmbedByTypeResponse,
                    construct_type(
                        type_=EmbedByTypeResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    async def rerank(
        self,
        *,
        model: str,
        query: str,
        documents: typing.Sequence[str],
        top_n: typing.Optional[int] = OMIT,
        max_tokens_per_doc: typing.Optional[int] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) ->
AsyncHttpResponse[V2RerankResponse]:
        """
        This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.

        Parameters
        ----------
        model : str
            The identifier of the model to use, eg `rerank-v3.5`.

        query : str
            The search query

        documents : typing.Sequence[str]
            A list of texts that will be compared to the `query`.
            For optimal performance we recommend against sending more than 1,000 documents in a single request.

            **Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`.

            **Note**: structured data should be formatted as YAML strings for best performance.

        top_n : typing.Optional[int]
            Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.

        max_tokens_per_doc : typing.Optional[int]
            Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[V2RerankResponse]
            OK
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v2/rerank",
            method="POST",
            json={
                "model": model,
                "query": query,
                "documents": documents,
                "top_n": top_n,
                "max_tokens_per_doc": max_tokens_per_doc,
                "priority": priority,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        # Same status-to-exception mapping as the other v2 endpoints; unknown
        # statuses and non-JSON bodies fall through to the generic ApiError.
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    V2RerankResponse,
                    construct_type(
                        type_=V2RerankResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            if _response.status_code == 400:
                raise BadRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)



================================================
FILE: src/cohere/v2/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file

import typing
from importlib import import_module

# Static imports for type checkers only; at runtime the names below are
# resolved lazily via the PEP 562 module-level __getattr__ further down,
# so importing this package does not eagerly import every type module.
if typing.TYPE_CHECKING:
    from .v2chat_request_documents_item import V2ChatRequestDocumentsItem
    from .v2chat_request_safety_mode import V2ChatRequestSafetyMode
    from .v2chat_request_tool_choice import V2ChatRequestToolChoice
    from .v2chat_response import V2ChatResponse
    from .v2chat_stream_request_documents_item import V2ChatStreamRequestDocumentsItem
    from .v2chat_stream_request_safety_mode import V2ChatStreamRequestSafetyMode
    from .v2chat_stream_request_tool_choice import V2ChatStreamRequestToolChoice
    from .v2chat_stream_response import (
        CitationEndV2ChatStreamResponse,
        CitationStartV2ChatStreamResponse,
        ContentDeltaV2ChatStreamResponse,
        ContentEndV2ChatStreamResponse,
        ContentStartV2ChatStreamResponse,
        DebugV2ChatStreamResponse,
        MessageEndV2ChatStreamResponse,
        MessageStartV2ChatStreamResponse,
        ToolCallDeltaV2ChatStreamResponse,
        ToolCallEndV2ChatStreamResponse,
        ToolCallStartV2ChatStreamResponse,
        ToolPlanDeltaV2ChatStreamResponse,
        V2ChatStreamResponse,
    )
    from .v2embed_request_truncate import V2EmbedRequestTruncate
    from .v2rerank_response import V2RerankResponse
    from .v2rerank_response_results_item import V2RerankResponseResultsItem

# Maps each public attribute name to the relative submodule that defines it.
_dynamic_imports: typing.Dict[str, str] = {
    "CitationEndV2ChatStreamResponse": ".v2chat_stream_response",
    "CitationStartV2ChatStreamResponse": ".v2chat_stream_response",
    "ContentDeltaV2ChatStreamResponse": ".v2chat_stream_response",
    "ContentEndV2ChatStreamResponse": ".v2chat_stream_response",
    "ContentStartV2ChatStreamResponse": ".v2chat_stream_response",
    "DebugV2ChatStreamResponse": ".v2chat_stream_response",
    "MessageEndV2ChatStreamResponse": ".v2chat_stream_response",
    "MessageStartV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolCallDeltaV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolCallEndV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolCallStartV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolPlanDeltaV2ChatStreamResponse": ".v2chat_stream_response",
    "V2ChatRequestDocumentsItem": ".v2chat_request_documents_item",
    "V2ChatRequestSafetyMode": ".v2chat_request_safety_mode",
    "V2ChatRequestToolChoice": ".v2chat_request_tool_choice",
    "V2ChatResponse": ".v2chat_response",
    "V2ChatStreamRequestDocumentsItem": ".v2chat_stream_request_documents_item",
    "V2ChatStreamRequestSafetyMode": ".v2chat_stream_request_safety_mode",
    "V2ChatStreamRequestToolChoice": ".v2chat_stream_request_tool_choice",
    "V2ChatStreamResponse": ".v2chat_stream_response",
    "V2EmbedRequestTruncate": ".v2embed_request_truncate",
    "V2RerankResponse": ".v2rerank_response",
    "V2RerankResponseResultsItem": ".v2rerank_response_results_item",
}


def __getattr__(attr_name: str) -> typing.Any:
    """Lazily import `attr_name` from its submodule on first access (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # If the attribute name matches the submodule name, return the module itself.
        if module_name == f".{attr_name}":
            return module
        else:
            return getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e


def __dir__() -> typing.List[str]:
    """Expose the lazily-importable names to dir()/autocomplete."""
    lazy_attrs = list(_dynamic_imports.keys())
    return sorted(lazy_attrs)


__all__ = [
    "CitationEndV2ChatStreamResponse",
    "CitationStartV2ChatStreamResponse",
    "ContentDeltaV2ChatStreamResponse",
    "ContentEndV2ChatStreamResponse",
    "ContentStartV2ChatStreamResponse",
    "DebugV2ChatStreamResponse",
    "MessageEndV2ChatStreamResponse",
    "MessageStartV2ChatStreamResponse",
    "ToolCallDeltaV2ChatStreamResponse",
    "ToolCallEndV2ChatStreamResponse",
    "ToolCallStartV2ChatStreamResponse",
    "ToolPlanDeltaV2ChatStreamResponse",
    "V2ChatRequestDocumentsItem",
    "V2ChatRequestSafetyMode",
    "V2ChatRequestToolChoice",
    "V2ChatResponse",
    "V2ChatStreamRequestDocumentsItem",
    "V2ChatStreamRequestSafetyMode",
    "V2ChatStreamRequestToolChoice",
    "V2ChatStreamResponse",
    "V2EmbedRequestTruncate",
    "V2RerankResponse",
    "V2RerankResponseResultsItem",
]



================================================
FILE: src/cohere/v2/types/v2chat_request_documents_item.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

from ...types.document import Document

# A chat document is either a raw string or a structured Document object.
V2ChatRequestDocumentsItem = typing.Union[str, Document]



================================================
FILE: src/cohere/v2/types/v2chat_request_safety_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# typing.Any is included so unrecognized server values don't fail validation.
V2ChatRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "OFF"], typing.Any]



================================================
FILE: src/cohere/v2/types/v2chat_request_tool_choice.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# typing.Any is included so unrecognized server values don't fail validation.
V2ChatRequestToolChoice = typing.Union[typing.Literal["REQUIRED", "NONE"], typing.Any]



================================================
FILE: src/cohere/v2/types/v2chat_response.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.assistant_message_response import AssistantMessageResponse
from ...types.chat_finish_reason import ChatFinishReason
from ...types.logprob_item import LogprobItem
from ...types.usage import Usage


class V2ChatResponse(UncheckedBaseModel):
    id: str = pydantic.Field()
    """
    Unique identifier for the generated reply. Useful for submitting feedback.
    """

    finish_reason: ChatFinishReason
    message: AssistantMessageResponse
    usage: typing.Optional[Usage] = None
    logprobs: typing.Optional[typing.List[LogprobItem]] = None

    # extra="allow" keeps forward compatibility: unknown response fields are
    # retained instead of rejected, under both pydantic v1 and v2.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow



================================================
FILE: src/cohere/v2/types/v2chat_stream_request_documents_item.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

from ...types.document import Document

# A chat document is either a raw string or a structured Document object.
V2ChatStreamRequestDocumentsItem = typing.Union[str, Document]



================================================
FILE: src/cohere/v2/types/v2chat_stream_request_safety_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# typing.Any is included so unrecognized server values don't fail validation.
V2ChatStreamRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "OFF"], typing.Any]



================================================
FILE: src/cohere/v2/types/v2chat_stream_request_tool_choice.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# typing.Any is included so unrecognized server values don't fail validation.
V2ChatStreamRequestToolChoice = typing.Union[typing.Literal["REQUIRED", "NONE"], typing.Any]



================================================
FILE: src/cohere/v2/types/v2chat_stream_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from ...types.chat_content_delta_event_delta import ChatContentDeltaEventDelta
from ...types.chat_content_start_event_delta import ChatContentStartEventDelta
from ...types.chat_message_end_event_delta import ChatMessageEndEventDelta
from ...types.chat_message_start_event_delta import ChatMessageStartEventDelta
from ...types.chat_tool_call_delta_event_delta import ChatToolCallDeltaEventDelta
from ...types.chat_tool_call_start_event_delta import ChatToolCallStartEventDelta
from ...types.chat_tool_plan_delta_event_delta import ChatToolPlanDeltaEventDelta
from ...types.citation_start_event_delta import CitationStartEventDelta
from ...types.logprob_item import LogprobItem

# Each class below is one streamed event variant. Every variant carries a
# literal `type` field that serves as the discriminant for the
# V2ChatStreamResponse union declared at the bottom of this file.


# Event tagged "message-start": optional message id plus a ChatMessageStartEventDelta.
class MessageStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["message-start"] = "message-start"
    id: typing.Optional[str] = None
    delta: typing.Optional[ChatMessageStartEventDelta] = None

    if IS_PYDANTIC_V2:
        # Keep unknown fields (extra="allow") so new server fields don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "content-start": content-block index plus a ChatContentStartEventDelta.
class ContentStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["content-start"] = "content-start"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatContentStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "content-delta": index, a ChatContentDeltaEventDelta, and optional logprobs.
class ContentDeltaV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["content-delta"] = "content-delta"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatContentDeltaEventDelta] = None
    logprobs: typing.Optional[LogprobItem] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "content-end": index only, no delta payload.
class ContentEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["content-end"] = "content-end"
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "tool-plan-delta": carries a ChatToolPlanDeltaEventDelta (no index field).
class ToolPlanDeltaV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-plan-delta"] = "tool-plan-delta"
    delta: typing.Optional[ChatToolPlanDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "tool-call-start": index plus a ChatToolCallStartEventDelta.
class ToolCallStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-call-start"] = "tool-call-start"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatToolCallStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "tool-call-delta": index plus a ChatToolCallDeltaEventDelta.
class ToolCallDeltaV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-call-delta"] = "tool-call-delta"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatToolCallDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "tool-call-end": index only, no delta payload.
class ToolCallEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-call-end"] = "tool-call-end"
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "citation-start": index plus a CitationStartEventDelta.
class CitationStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["citation-start"] = "citation-start"
    index: typing.Optional[int] = None
    delta: typing.Optional[CitationStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "citation-end": index only, no delta payload.
class CitationEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["citation-end"] = "citation-end"
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "message-end": optional id plus a ChatMessageEndEventDelta.
class MessageEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["message-end"] = "message-end"
    id: typing.Optional[str] = None
    delta: typing.Optional[ChatMessageEndEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Event tagged "debug": carries an optional prompt string.
class DebugV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["debug"] = "debug"
    prompt: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union over every streamed event variant; UnionMetadata tells the
# deserializer to dispatch on the `type` field.
V2ChatStreamResponse = typing_extensions.Annotated[
    typing.Union[
        MessageStartV2ChatStreamResponse,
        ContentStartV2ChatStreamResponse,
        ContentDeltaV2ChatStreamResponse,
        ContentEndV2ChatStreamResponse,
        ToolPlanDeltaV2ChatStreamResponse,
        ToolCallStartV2ChatStreamResponse,
        ToolCallDeltaV2ChatStreamResponse,
        ToolCallEndV2ChatStreamResponse,
        CitationStartV2ChatStreamResponse,
        CitationEndV2ChatStreamResponse,
        MessageEndV2ChatStreamResponse,
        DebugV2ChatStreamResponse,
    ],
    UnionMetadata(discriminant="type"),
]


================================================
FILE: src/cohere/v2/types/v2embed_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

# Open enum: the union with typing.Any accepts values beyond the listed literals.
V2EmbedRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]


================================================
FILE: src/cohere/v2/types/v2rerank_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.api_meta import ApiMeta
from .v2rerank_response_results_item import V2RerankResponseResultsItem


# Response body of the v2 rerank endpoint.
class V2RerankResponse(UncheckedBaseModel):
    id: typing.Optional[str] = None
    results: typing.List[V2RerankResponseResultsItem] = pydantic.Field()
    """
    An ordered list of ranked documents
    """

    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unknown fields (extra="allow") so new server fields don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/v2/types/v2rerank_response_results_item.py
================================================
# This file was auto-generated by Fern from our API Definition.

import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel


# One ranked document entry inside V2RerankResponse.results.
class V2RerankResponseResultsItem(UncheckedBaseModel):
    index: int = pydantic.Field()
    """
    Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance)
    """

    relevance_score: float = pydantic.Field()
    """
    Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


================================================
FILE: src/cohere/version.py
================================================
# Resolve the installed package version from distribution metadata so it is
# defined in exactly one place (the package build configuration).
from importlib import metadata

__version__ = metadata.version("cohere")


================================================
FILE: tests/__init__.py
================================================


================================================
FILE: tests/embed_job.jsonl
================================================
{"text": "The quick brown fox jumps over the lazy dog"}


================================================
FILE: tests/test_async_client.py
================================================
import os
import unittest

import cohere
from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \
    ToolParameterDefinitionsValue, ToolResult, UserMessage, ChatbotMessage

# Resolve the fixture path relative to this file so tests work from any CWD.
package_dir = os.path.dirname(os.path.abspath(__file__))
embed_job = os.path.join(package_dir, 'embed_job.jsonl')


class TestClient(unittest.IsolatedAsyncioTestCase):
    """Integration tests for cohere.AsyncClient.

    These call the live API (API key/URL come from the environment); several
    are gated with skipIf/skip. Many tests only print the response — they
    assert the call succeeds rather than pinning response content.
    """

    co: cohere.AsyncClient

    def setUp(self) -> None:
        self.co = cohere.AsyncClient(timeout=10000)

    async def test_token_falls_back_on_env_variable(self) -> None:
        # Constructing without an explicit key must not raise (env fallback).
        cohere.AsyncClient(api_key=None)
        cohere.AsyncClient(None)

    async def test_context_manager(self) -> None:
        async with cohere.AsyncClient(api_key="xxx") as client:
            self.assertIsNotNone(client)

    async def test_chat(self) -> None:
        chat = await self.co.chat(
            model="command-a-03-2025",
            chat_history=[
                UserMessage(message="Who discovered gravity?"),
                ChatbotMessage(message="The man who is widely credited with discovering "
                                       "gravity is Sir Isaac Newton")
            ],
            message="What year was he born?",
        )
        print(chat)

    async def test_chat_stream(self) -> None:
        stream = self.co.chat_stream(
            model="command-a-03-2025",
            chat_history=[
                UserMessage(message="Who discovered gravity?"),
                ChatbotMessage(message="The man who is widely credited with discovering "
                                       "gravity is Sir Isaac Newton")
            ],
            message="What year was he born?",
        )

        # Collect every event type seen on the stream, then check the
        # start/generation/end lifecycle events all occurred.
        events = set()
        async for chat_event in stream:
            events.add(chat_event.event_type)
            if chat_event.event_type == "text-generation":
                print(chat_event.text)

        self.assertTrue("text-generation" in events)
        self.assertTrue("stream-start" in events)
        self.assertTrue("stream-end" in events)

    async def test_stream_equals_true(self) -> None:
        # Passing stream=True to the non-streaming method must be rejected.
        with self.assertRaises(ValueError):
            await self.co.chat(
                stream=True,  # type: ignore
                message="What year was he born?",
            )

    async def test_deprecated_fn(self) -> None:
        with self.assertRaises(ValueError):
            await self.co.check_api_key("dummy", dummy="dummy")  # type: ignore

    async def test_moved_fn(self) -> None:
        with self.assertRaises(ValueError):
            await self.co.list_connectors("dummy", dummy="dummy")  # type: ignore

    @unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
    async def test_embed(self) -> None:
        response = await self.co.embed(
            texts=['hello', 'goodbye'],
            model='embed-english-v3.0',
            input_type="classification"
        )
        print(response)

    async def test_embed_batch_types(self) -> None:
        # batch more than 96 texts
        response = await self.co.embed(
            texts=['hello']*100,
            model='embed-english-v3.0',
            input_type="classification",
            embedding_types=["float", "int8", "uint8", "binary", "ubinary"]
        )
        if response.response_type == "embeddings_by_type":
            # Every requested embedding type must come back for all 100 texts.
            self.assertEqual(len(response.texts or []), 100)
            self.assertEqual(len(response.embeddings.float_ or []), 100)
            self.assertEqual(len(response.embeddings.int8 or []), 100)
            self.assertEqual(len(response.embeddings.uint8 or []), 100)
            self.assertEqual(len(response.embeddings.binary or []), 100)
            self.assertEqual(len(response.embeddings.ubinary or []), 100)
        else:
            self.fail("Expected embeddings_by_type response type")
        print(response)

    async def test_embed_batch_v1(self) -> None:
        # batch more than 96 texts
        response = await self.co.embed(
            texts=['hello']*100,
            model='embed-english-v3.0',
            input_type="classification",
        )
        if response.response_type == "embeddings_floats":
            self.assertEqual(len(response.embeddings), 100)
        else:
            self.fail("Expected embeddings_floats response type")
        print(response)

    @unittest.skip("temp")
    async def test_embed_job_crud(self) -> None:
        # NOTE(review): file handles opened here are never closed — consider
        # `with open(...)`; left as-is to avoid changing test behavior.
        dataset = await self.co.datasets.create(
            name="test",
            type="embed-input",
            data=open(embed_job, 'rb'),
        )
        result = await self.co.wait(dataset)
        self.assertEqual(result.dataset.validation_status, "validated")

        # start an embed job
        job = await self.co.embed_jobs.create(
            dataset_id=dataset.id or "", input_type="search_document", model='embed-english-v3.0')
        print(job)

        # list embed jobs
        my_embed_jobs = await self.co.embed_jobs.list()
        print(my_embed_jobs)

        emb_result = await self.co.wait(job)
        self.assertEqual(emb_result.status, "complete")

        await self.co.embed_jobs.cancel(job.job_id)
        await self.co.datasets.delete(dataset.id or "")

    async def test_rerank(self) -> None:
        docs = [
            'Carson City is the capital city of the American state of Nevada.',
            'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
            'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
            'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
        response = await self.co.rerank(
            model='rerank-v3.5',
            query='What is the capital of the United States?',
            documents=docs,
            top_n=3,
        )
        print(response)

    @unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
    async def test_datasets_crud(self) -> None:
        my_dataset = await self.co.datasets.create(
            name="test",
            type="embed-input",
            data=open(embed_job, 'rb'),
        )
        print(my_dataset)

        my_datasets = await self.co.datasets.list()
        print(my_datasets)

        dataset = await self.co.datasets.get(my_dataset.id or "")
        print(dataset)

        await self.co.datasets.delete(my_dataset.id or "")

    @unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
    async def test_save_load(self) -> None:
        my_dataset = await self.co.datasets.create(
            name="test",
            type="embed-input",
            data=open(embed_job, 'rb'),
        )
        result = await self.co.wait(my_dataset)
        self.co.utils.save_dataset(result.dataset, "dataset.jsonl")
        # assert files equal
        self.assertTrue(os.path.exists("dataset.jsonl"))
        self.assertEqual(open(embed_job, 'rb').read(), open("dataset.jsonl", 'rb').read())
        print(result)
        await self.co.datasets.delete(my_dataset.id or "")

    async def test_tokenize(self) -> None:
        # offline=False forces the API tokenizer (cf. test_local_tokenize below).
        response = await self.co.tokenize(
            text='tokenize me! :D',
            model="command-a-03-2025",
            offline=False,
        )
        print(response)

    async def test_detokenize(self) -> None:
        response = await self.co.detokenize(
            tokens=[10104, 12221, 1315, 34, 1420, 69],
            model="command-a-03-2025",
            offline=False,
        )
        print(response)

    @unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
    async def test_tool_use(self) -> None:
        # Round-trip: ask the model to call a tool, execute it locally, then
        # feed the results back and check the cited documents.
        tools = [
            Tool(
                name="sales_database",
                description="Connects to a database about sales volumes",
                parameter_definitions={
                    "day": ToolParameterDefinitionsValue(
                        description="Retrieves sales data from this day, formatted as YYYY-MM-DD.",
                        type="str",
                        required=True
                    )}
            )
        ]

        tool_parameters_response = await self.co.chat(
            message="How good were the sales on September 29 2023?",
            tools=tools,
            model="command-nightly",
            preamble="""
## Task Description
You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.

## Style Guide
Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.
"""
        )

        if tool_parameters_response.tool_calls is not None:
            self.assertEqual(tool_parameters_response.tool_calls[0].name, "sales_database")
            self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {"day": "2023-09-29"})
        else:
            raise ValueError("Expected tool calls to be present")

        # Stand-in for a real backend: one canned result per tool name.
        local_tools = {
            "sales_database": lambda day: {
                "number_of_sales": 120,
                "total_revenue": 48500,
                "average_sale_value": 404.17,
                "date": "2023-09-29"
            }
        }

        tool_results = []
        for tool_call in tool_parameters_response.tool_calls:
            output = local_tools[tool_call.name](**tool_call.parameters)
            outputs = [output]
            tool_results.append(ToolResult(
                call=tool_call,
                outputs=outputs
            ))

        cited_response = await self.co.chat(
            message="How good were the sales on September 29?",
            tools=tools,
            tool_results=tool_results,
            force_single_step=True,
            model="command-a-03-2025",
        )

        # Note: document values come back stringified by the API.
        self.assertEqual(cited_response.documents, [
            {
                "average_sale_value": "404.17",
                "date": "2023-09-29",
                "id": "sales_database:0:0",
                "number_of_sales": "120",
                "total_revenue": "48500",
            }
        ])

    async def test_local_tokenize(self) -> None:
        response = await self.co.tokenize(
            model="command-a-03-2025",
            text="tokenize me! :D"
        )
        print(response)

    async def test_local_detokenize(self) -> None:
        response = await self.co.detokenize(
            model="command-a-03-2025",
            tokens=[10104, 12221, 1315, 34, 1420, 69]
        )
        print(response)

    async def test_tokenize_async_context_with_sync_client(self) -> None:
        # Test that the sync client can be used in an async context.
        co = cohere.Client(timeout=10000)
        print(co.tokenize(model="command-a-03-2025", text="tokenize me! :D"))
        print(co.detokenize(model="command-a-03-2025", tokens=[10104, 12221, 1315, 34, 1420, 69]))


================================================
FILE: tests/test_aws_client_unit.py
================================================
"""
Unit tests (mocked, no AWS credentials needed) for AWS client fixes.
Covers:
- Fix 1: SigV4 signing uses the correct host header after URL rewrite
- Fix 2: cohere_aws.Client conditionally initializes based on mode
- Fix 3: embed() accepts and passes output_dimension and embedding_types
"""
import inspect
import json
import os
import unittest
from unittest.mock import MagicMock, patch

import httpx

from cohere.manually_maintained.cohere_aws.mode import Mode


class TestSigV4HostHeader(unittest.TestCase):
    """Fix 1: The headers dict passed to AWSRequest for SigV4 signing must
    contain the rewritten Bedrock/SageMaker host, not the stale api.cohere.com."""

    def test_sigv4_signs_with_correct_host(self) -> None:
        # Capture the kwargs that map_request_to_bedrock passes to AWSRequest
        # so we can inspect the host header used for signing.
        captured_aws_request_kwargs: dict = {}
        mock_aws_request_cls = MagicMock()

        def capture_aws_request(**kwargs):  # type: ignore
            captured_aws_request_kwargs.update(kwargs)
            mock_req = MagicMock()
            mock_req.prepare.return_value = MagicMock(
                headers={"host": "bedrock-runtime.us-east-1.amazonaws.com"}
            )
            return mock_req

        mock_aws_request_cls.side_effect = capture_aws_request

        # Stand-in botocore/boto3 modules so no real AWS SDK (or creds) is needed.
        mock_botocore = MagicMock()
        mock_botocore.awsrequest.AWSRequest = mock_aws_request_cls
        mock_botocore.auth.SigV4Auth.return_value = MagicMock()

        mock_boto3 = MagicMock()
        mock_session = MagicMock()
        mock_session.region_name = "us-east-1"
        mock_session.get_credentials.return_value = MagicMock()
        mock_boto3.Session.return_value = mock_session

        with patch("cohere.aws_client.lazy_botocore", return_value=mock_botocore), \
                patch("cohere.aws_client.lazy_boto3", return_value=mock_boto3):
            from cohere.aws_client import map_request_to_bedrock

            hook = map_request_to_bedrock(service="bedrock", aws_region="us-east-1")

            request = httpx.Request(
                method="POST",
                url="https://api.cohere.com/v1/chat",
                headers={"connection": "keep-alive"},
                json={"model": "cohere.command-r-plus-v1:0", "message": "hello"},
            )
            # Precondition: the request still points at the public API host.
            self.assertEqual(request.url.host, "api.cohere.com")

            hook(request)

        # The hook must have rewritten the URL to the Bedrock endpoint...
        self.assertIn("bedrock-runtime.us-east-1.amazonaws.com", str(request.url))
        # ...and signed with the rewritten host, not api.cohere.com.
        signed_headers = captured_aws_request_kwargs["headers"]
        self.assertEqual(
            signed_headers["host"],
            "bedrock-runtime.us-east-1.amazonaws.com",
        )


class TestModeConditionalInit(unittest.TestCase):
    """Fix 2: cohere_aws.Client should initialize different boto3 clients
    depending on mode, and default to SAGEMAKER for backwards compat."""

    def test_sagemaker_mode_creates_sagemaker_clients(self) -> None:
        mock_boto3 = MagicMock()
        mock_sagemaker = MagicMock()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-east-1")
            self.assertEqual(client.mode, Mode.SAGEMAKER)
            # First positional arg of each boto3.client(...) call is the service name.
            service_names = [c[0][0] for c in mock_boto3.client.call_args_list]
            self.assertIn("sagemaker-runtime", service_names)
            self.assertIn("sagemaker", service_names)
            self.assertNotIn("bedrock-runtime", service_names)
            self.assertNotIn("bedrock", service_names)
            mock_sagemaker.Session.assert_called_once()

    def test_bedrock_mode_creates_bedrock_clients(self) -> None:
        mock_boto3 = MagicMock()
        mock_sagemaker = MagicMock()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-west-2"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-west-2", mode=Mode.BEDROCK)
            self.assertEqual(client.mode, Mode.BEDROCK)
            service_names = [c[0][0] for c in mock_boto3.client.call_args_list]
            self.assertIn("bedrock-runtime", service_names)
            self.assertIn("bedrock", service_names)
            self.assertNotIn("sagemaker-runtime", service_names)
            self.assertNotIn("sagemaker", service_names)
            # Bedrock mode must not touch the sagemaker package at all.
            mock_sagemaker.Session.assert_not_called()

    def test_default_mode_is_sagemaker(self) -> None:
        from cohere.manually_maintained.cohere_aws.client import Client

        # Inspect the signature instead of instantiating: no AWS setup required.
        sig = inspect.signature(Client.__init__)
        self.assertEqual(sig.parameters["mode"].default, Mode.SAGEMAKER)


class TestEmbedV4Params(unittest.TestCase):
    """Fix 3: embed() should accept output_dimension and embedding_types, pass
    them through to the request body, and strip them when None."""

    @staticmethod
    def _make_bedrock_client():  # type: ignore
        # Build mock boto3/botocore where invoke_model records the JSON body
        # it receives into `captured_body` for later assertions.
        mock_boto3 = MagicMock()
        mock_botocore = MagicMock()
        captured_body: dict = {}

        def fake_invoke_model(**kwargs):  # type: ignore
            captured_body.update(json.loads(kwargs["body"]))
            mock_body = MagicMock()
            mock_body.read.return_value = json.dumps({"embeddings": [[0.1, 0.2]]}).encode()
            return {"body": mock_body}

        mock_bedrock_client = MagicMock()
        mock_bedrock_client.invoke_model.side_effect = fake_invoke_model

        def fake_boto3_client(service_name, **kwargs):  # type: ignore
            if service_name == "bedrock-runtime":
                return mock_bedrock_client
            return MagicMock()

        mock_boto3.client.side_effect = fake_boto3_client
        return mock_boto3, mock_botocore, captured_body

    def test_embed_accepts_new_params(self) -> None:
        from cohere.manually_maintained.cohere_aws.client import Client

        sig = inspect.signature(Client.embed)
        self.assertIn("output_dimension", sig.parameters)
        self.assertIn("embedding_types", sig.parameters)
        self.assertIsNone(sig.parameters["output_dimension"].default)
        self.assertIsNone(sig.parameters["embedding_types"].default)

    def test_embed_passes_params_to_bedrock(self) -> None:
        mock_boto3, mock_botocore, captured_body = self._make_bedrock_client()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
            client.embed(
                texts=["hello world"],
                input_type="search_document",
                model_id="cohere.embed-english-v3",
                output_dimension=256,
                embedding_types=["float", "int8"],
            )
            self.assertEqual(captured_body["output_dimension"], 256)
            self.assertEqual(captured_body["embedding_types"], ["float", "int8"])

    def test_embed_omits_none_params(self) -> None:
        mock_boto3, mock_botocore, captured_body = self._make_bedrock_client()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
            client.embed(
                texts=["hello world"],
                input_type="search_document",
                model_id="cohere.embed-english-v3",
            )
            # None-valued params must be stripped from the request body entirely.
            self.assertNotIn("output_dimension", captured_body)
            self.assertNotIn("embedding_types", captured_body)

    def test_embed_with_embedding_types_returns_dict(self) -> None:
        """When embedding_types is specified, the API returns embeddings as a dict.
        The client should return that dict rather than wrapping it in Embeddings."""
        mock_boto3 = MagicMock()
        mock_botocore = MagicMock()
        by_type_embeddings = {"float": [[0.1, 0.2]], "int8": [[1, 2]]}

        def fake_invoke_model(**kwargs):  # type: ignore
            mock_body = MagicMock()
            mock_body.read.return_value = json.dumps({
                "embeddings": by_type_embeddings,
                "response_type": "embeddings_by_type",
            }).encode()
            return {"body": mock_body}

        mock_bedrock_client = MagicMock()
        mock_bedrock_client.invoke_model.side_effect = fake_invoke_model

        def fake_boto3_client(service_name, **kwargs):  # type: ignore
            if service_name == "bedrock-runtime":
                return mock_bedrock_client
            return MagicMock()

        mock_boto3.client.side_effect = fake_boto3_client

        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
            result = client.embed(
                texts=["hello world"],
                input_type="search_document",
                model_id="cohere.embed-english-v3",
                embedding_types=["float", "int8"],
            )
            self.assertIsInstance(result, dict)
            self.assertEqual(result, by_type_embeddings)


================================================
FILE: tests/test_bedrock_client.py
================================================
import os
import unittest
import typing

import cohere

# Test credentials/config come from custom env vars (see _setup_boto3_env below).
aws_access_key = os.getenv("AWS_ACCESS_KEY")
aws_secret_key = os.getenv("AWS_SECRET_KEY")
aws_session_token = os.getenv("AWS_SESSION_TOKEN")
aws_region = os.getenv("AWS_REGION")
endpoint_type = os.getenv("ENDPOINT_TYPE")


def _setup_boto3_env():
    """Bridge custom test env vars to standard boto3 credential env vars."""
    if aws_access_key:
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key
    if aws_secret_key:
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_key
    if aws_session_token:
        os.environ["AWS_SESSION_TOKEN"] = aws_session_token


@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set")
class TestClient(unittest.TestCase):
    # Integration tests for cohere.BedrockClient; gated on TEST_AWS.
    platform: str = "bedrock"
    models: typing.Dict[str, str] = {
        "chat_model": "cohere.command-r-plus-v1:0",
        "embed_model": "cohere.embed-multilingual-v3",
        "generate_model": "cohere.command-text-v14",
    }

    def setUp(self) -> None:
        self.client = cohere.BedrockClient(
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )

    def test_rerank(self) -> None:
        if self.platform != "sagemaker":
            self.skipTest("Only sagemaker supports rerank")
        docs = [
            'Carson City is the capital city of the American state of Nevada.',
            'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
            'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
            'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
        response = self.client.rerank(
            model=self.models["rerank_model"],
            query='What is the capital of the United States?',
            documents=docs,
            top_n=3,
        )
        self.assertEqual(len(response.results), 3)

    def test_embed(self) -> None:
        response = self.client.embed(
            model=self.models["embed_model"],
            texts=["I love Cohere!"],
            input_type="search_document",
        )
        print(response)

    def test_generate(self) -> None:
        response = self.client.generate(
            model=self.models["generate_model"],
            prompt='Please explain to me how LLMs work',
        )
        print(response)

    def test_generate_stream(self) -> None:
        response = self.client.generate_stream(
            model=self.models["generate_model"],
            prompt='Please explain to me how LLMs work',
        )
        for event in response:
            print(event)
            if event.event_type == "text-generation":
                print(event.text, end='')

    def test_chat(self) -> None:
        response = self.client.chat(
            model=self.models["chat_model"],
            message='Please explain to me how LLMs work',
        )
        print(response)

        self.assertIsNotNone(response.text)
        self.assertIsNotNone(response.generation_id)
        self.assertIsNotNone(response.finish_reason)

        # Guarded drill-down: meta and its sub-objects are optional on the response.
        self.assertIsNotNone(response.meta)
        if response.meta is not None:
            self.assertIsNotNone(response.meta.tokens)
            if response.meta.tokens is not None:
                self.assertIsNotNone(response.meta.tokens.input_tokens)
                self.assertIsNotNone(response.meta.tokens.output_tokens)

            self.assertIsNotNone(response.meta.billed_units)
            if response.meta.billed_units is not None:
                self.assertIsNotNone(response.meta.billed_units.input_tokens)
                # NOTE(review): input_tokens is asserted twice here; the second
                # was presumably meant to be output_tokens — confirm and fix.
                self.assertIsNotNone(response.meta.billed_units.input_tokens)

    def test_chat_stream(self) -> None:
        response_types = set()

        response = self.client.chat_stream(
            model=self.models["chat_model"],
            message='Please explain to me how LLMs work',
        )
        for event in response:
            response_types.add(event.event_type)
            if event.event_type == "text-generation":
                print(event.text, end='')
                self.assertIsNotNone(event.text)
            if event.event_type == "stream-end":
                self.assertIsNotNone(event.finish_reason)
                self.assertIsNotNone(event.response)
                self.assertIsNotNone(event.response.text)

        self.assertSetEqual(response_types, {"text-generation", "stream-end"})


@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set")
class TestBedrockClientV2(unittest.TestCase):
    """Integration tests for BedrockClientV2 (httpx-based).

    Fix 1 validation: If these pass, SigV4 signing uses the correct host header,
    since the request would fail with a signature mismatch otherwise.
    """

    def setUp(self) -> None:
        self.client = cohere.BedrockClientV2(
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )

    def test_embed(self) -> None:
        response = self.client.embed(
            model="cohere.embed-multilingual-v3",
            texts=["I love Cohere!"],
            input_type="search_document",
            embedding_types=["float"],
        )
        self.assertIsNotNone(response)

    def test_embed_with_output_dimension(self) -> None:
        response = self.client.embed(
            model="cohere.embed-english-v3",
            texts=["I love Cohere!"],
            input_type="search_document",
            embedding_types=["float"],
            output_dimension=256,
        )
        self.assertIsNotNone(response)


@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set")
class TestCohereAwsBedrockClient(unittest.TestCase):
    """Integration tests for cohere_aws.Client in Bedrock mode (boto3-based).

    Validates:
    - Fix 2: Client can be initialized with mode=BEDROCK without importing sagemaker
    - Fix 3: embed() accepts output_dimension and embedding_types
    """

    client: typing.Any = None

    @classmethod
    def setUpClass(cls) -> None:
        # Client construction is shared across tests; boto3 reads creds from env.
        _setup_boto3_env()
        from cohere.manually_maintained.cohere_aws.client import Client
        from cohere.manually_maintained.cohere_aws.mode import Mode

        cls.client = Client(aws_region=aws_region, mode=Mode.BEDROCK)

    def test_client_is_bedrock_mode(self) -> None:
        from cohere.manually_maintained.cohere_aws.mode import Mode

        self.assertEqual(self.client.mode, Mode.BEDROCK)

    def test_embed(self) -> None:
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-multilingual-v3",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        self.assertGreater(len(response.embeddings), 0)

    def test_embed_with_embedding_types(self) -> None:
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-multilingual-v3",
            embedding_types=["float"],
        )
        self.assertIsNotNone(response)
        # When embedding_types is passed, the response is a raw dict
        self.assertIsInstance(response, dict)
        self.assertIn("float", response)

    def test_embed_with_output_dimension(self) -> None:
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-english-v3",
            output_dimension=256,
            embedding_types=["float"],
        )
        self.assertIsNotNone(response)
        # When embedding_types is passed, the response is a raw dict
        self.assertIsInstance(response, dict)
        self.assertIn("float", response)

    def test_embed_without_new_params(self) -> None:
        """Backwards compat: embed() still works without the new v4 params."""
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-multilingual-v3",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)


================================================
FILE: tests/test_client.py ================================================ import json import os import unittest import cohere from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \ ToolParameterDefinitionsValue, ToolResult, ChatbotMessage, UserMessage, JsonObjectResponseFormat co = cohere.Client(timeout=10000) package_dir = os.path.dirname(os.path.abspath(__file__)) embed_job = os.path.join(package_dir, 'embed_job.jsonl') class TestClient(unittest.TestCase): def test_token_falls_back_on_env_variable(self) -> None: cohere.Client(api_key=None) cohere.Client(None) def test_context_manager(self) -> None: with cohere.Client(api_key="xxx") as client: self.assertIsNotNone(client) def test_chat(self) -> None: chat = co.chat( chat_history=[ UserMessage( message="Who discovered gravity?"), ChatbotMessage(message="The man who is widely credited with discovering " "gravity is Sir Isaac Newton") ], message="What year was he born?", ) print(chat) def test_chat_stream(self) -> None: stream = co.chat_stream( chat_history=[ UserMessage( message="Who discovered gravity?"), ChatbotMessage(message="The man who is widely credited with discovering " "gravity is Sir Isaac Newton") ], message="What year was he born?", ) events = set() for chat_event in stream: events.add(chat_event.event_type) if chat_event.event_type == "text-generation": print(chat_event.text) self.assertTrue("text-generation" in events) self.assertTrue("stream-start" in events) self.assertTrue("stream-end" in events) def test_stream_equals_true(self) -> None: with self.assertRaises(ValueError): co.chat( stream=True, # type: ignore message="What year was he born?", ) def test_deprecated_fn(self) -> None: with self.assertRaises(ValueError): co.check_api_key("dummy", dummy="dummy") # type: ignore def test_moved_fn(self) -> None: with self.assertRaises(ValueError): co.list_connectors("dummy", dummy="dummy") # type: ignore def test_embed(self) -> None: response = co.embed( texts=['hello', 
'goodbye'], model='embed-english-v3.0', input_type="classification", embedding_types=["float", "int8", "uint8", "binary", "ubinary"] ) if response.response_type == "embeddings_by_type": self.assertIsNotNone(response.embeddings.float) # type: ignore self.assertIsNotNone(response.embeddings.float_) if response.embeddings.float_ is not None: self.assertEqual(type(response.embeddings.float_[0][0]), float) if response.embeddings.int8 is not None: self.assertEqual(type(response.embeddings.int8[0][0]), int) if response.embeddings.uint8 is not None: self.assertEqual(type(response.embeddings.uint8[0][0]), int) if response.embeddings.binary is not None: self.assertEqual(type(response.embeddings.binary[0][0]), int) if response.embeddings.ubinary is not None: self.assertEqual(type(response.embeddings.ubinary[0][0]), int) print(response) def test_image_embed(self) -> None: response = co.embed( images=['data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nN
u5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII='], model='embed-multilingual-v3.0', input_type="image", embedding_types=["float"] ) if response.response_type == "embeddings_by_type": self.assertIsNotNone(response.embeddings.float) # type: ignore self.assertIsNotNone(response.embeddings.float_) if response.embeddings.float_ is not None: self.assertEqual(type(response.embeddings.float_[0][0]), float) if response.embeddings.int8 is not None: self.assertEqual(type(response.embeddings.int8[0][0]), int) if response.embeddings.uint8 is not None: self.assertEqual(type(response.embeddings.uint8[0][0]), int) if response.embeddings.binary is not None: self.assertEqual(type(response.embeddings.binary[0][0]), int) if response.embeddings.ubinary is not None: self.assertEqual(type(response.embeddings.ubinary[0][0]), int) print(response) def test_embed_batch_types(self) -> None: # batch more than 96 texts response = co.embed( texts=['hello'] * 100, model='embed-english-v3.0', input_type="classification", embedding_types=["float", "int8", "uint8", "binary", "ubinary"] ) if response.response_type == "embeddings_by_type": self.assertEqual(len(response.texts or []), 100) self.assertEqual(len(response.embeddings.float_ or []), 100) self.assertEqual(len(response.embeddings.int8 or []), 100) self.assertEqual(len(response.embeddings.uint8 or []), 100) self.assertEqual(len(response.embeddings.binary or []), 100) self.assertEqual(len(response.embeddings.ubinary or []), 100) else: self.fail("Expected embeddings_by_type response type") print(response) def test_embed_batch_v1(self) -> None: # batch more than 96 texts response = co.embed( texts=['hello'] * 100, model='embed-english-v3.0', input_type="classification", ) if response.response_type == "embeddings_floats": self.assertEqual(len(response.embeddings), 100) else: 
self.fail("Expected embeddings_floats response type") print(response) @unittest.skip("temp") def test_embed_job_crud(self) -> None: dataset = co.datasets.create( name="test", type="embed-input", data=open(embed_job, 'rb'), ) result = co.wait(dataset) self.assertEqual(result.dataset.validation_status, "validated") # start an embed job job = co.embed_jobs.create( dataset_id=dataset.id or "", input_type="search_document", model='embed-english-v3.0') print(job) # list embed jobs my_embed_jobs = co.embed_jobs.list() print(my_embed_jobs) emb_result = co.wait(job) self.assertEqual(emb_result.status, "complete") co.embed_jobs.cancel(job.job_id) co.datasets.delete(dataset.id or "") def test_rerank(self) -> None: docs = [ 'Carson City is the capital city of the American state of Nevada.', 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.', 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.', 'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. 
As of 2017, capital punishment is legal in 30 of the 50 states.'] response = co.rerank( model='rerank-v3.5', query='What is the capital of the United States?', documents=docs, top_n=3, ) print(response) @unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.") def test_datasets_crud(self) -> None: my_dataset = co.datasets.create( name="test", type="embed-input", data=open(embed_job, 'rb'), ) print(my_dataset) my_datasets = co.datasets.list() print(my_datasets) dataset = co.datasets.get(my_dataset.id or "") print(dataset) co.datasets.delete(my_dataset.id or "") @unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.") def test_save_load(self) -> None: my_dataset = co.datasets.create( name="test", type="embed-input", data=open(embed_job, 'rb'), ) result = co.wait(my_dataset) co.utils.save_dataset(result.dataset, "dataset.jsonl") # assert files equal self.assertTrue(os.path.exists("dataset.jsonl")) self.assertEqual(open(embed_job, 'rb').read(), open("dataset.jsonl", 'rb').read()) print(result) co.datasets.delete(my_dataset.id or "") def test_tokenize(self) -> None: response = co.tokenize( text='tokenize me! 
:D', model='command-a-03-2025', offline=False, ) print(response) def test_detokenize(self) -> None: response = co.detokenize( tokens=[10104, 12221, 1315, 34, 1420, 69], model="command-a-03-2025", offline=False, ) print(response) @unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.") def test_tool_use(self) -> None: tools = [ Tool( name="sales_database", description="Connects to a database about sales volumes", parameter_definitions={ "day": ToolParameterDefinitionsValue( description="Retrieves sales data from this day, formatted as YYYY-MM-DD.", type="str", required=True )} ) ] tool_parameters_response = co.chat( message="How good were the sales on September 29 2023?", tools=tools, model="command-nightly", preamble=""" ## Task Description You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging. ## Style Guide Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling. 
""" ) if tool_parameters_response.tool_calls is not None: self.assertEqual( tool_parameters_response.tool_calls[0].name, "sales_database") self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { "day": "2023-09-29"}) else: raise ValueError("Expected tool calls to be present") local_tools = { "sales_database": lambda day: { "number_of_sales": 120, "total_revenue": 48500, "average_sale_value": 404.17, "date": "2023-09-29" } } tool_results = [] for tool_call in tool_parameters_response.tool_calls: output = local_tools[tool_call.name](**tool_call.parameters) outputs = [output] tool_results.append(ToolResult( call=tool_call, outputs=outputs )) cited_response = co.chat( message="How good were the sales on September 29?", tools=tools, tool_results=tool_results, force_single_step=True, model="command-nightly", ) self.assertEqual(cited_response.documents, [ { "average_sale_value": "404.17", "date": "2023-09-29", "id": "sales_database:0:0", "number_of_sales": "120", "total_revenue": "48500", } ]) def test_local_tokenize(self) -> None: response = co.tokenize( model="command-a-03-2025", text="tokenize me! 
:D" ) print(response) def test_local_detokenize(self) -> None: response = co.detokenize( model="command-a-03-2025", tokens=[10104, 12221, 1315, 34, 1420, 69] ) print(response) ================================================ FILE: tests/test_client_init.py ================================================ import os import typing import unittest import cohere from cohere import ToolMessage, UserMessage, AssistantMessage import importlib.util HAS_BOTO3 = importlib.util.find_spec("boto3") is not None class TestClientInit(unittest.TestCase): @unittest.skipUnless(HAS_BOTO3, "boto3 not installed") def test_aws_inits(self) -> None: cohere.BedrockClient() cohere.BedrockClientV2() cohere.SagemakerClient() cohere.SagemakerClientV2() def test_inits(self) -> None: cohere.Client(api_key="n/a") cohere.ClientV2(api_key="n/a") ================================================ FILE: tests/test_client_v2.py ================================================ import os import typing import unittest import cohere from cohere import ToolMessage, UserMessage, AssistantMessage co = cohere.ClientV2(timeout=10000) package_dir = os.path.dirname(os.path.abspath(__file__)) embed_job = os.path.join(package_dir, "embed_job.jsonl") class TestClientV2(unittest.TestCase): def test_chat(self) -> None: response = co.chat( model="command-a-03-2025", messages=[cohere.UserChatMessageV2(content="hello world!")]) print(response.message) def test_chat_stream(self) -> None: stream = co.chat_stream( model="command-a-03-2025", messages=[cohere.UserChatMessageV2(content="hello world!")]) events = set() for chat_event in stream: if chat_event is not None: events.add(chat_event.type) if chat_event.type == "content-delta": print(chat_event.delta) self.assertTrue("message-start" in events) self.assertTrue("content-start" in events) self.assertTrue("content-delta" in events) self.assertTrue("content-end" in events) self.assertTrue("message-end" in events) def test_legacy_methods_available(self) -> None: 
        # (continuation of TestClientV2.test_legacy_methods_available, whose
        # `def` line precedes this chunk boundary)
        self.assertTrue(hasattr(co, "generate"))
        self.assertTrue(callable(getattr(co, "generate")))
        self.assertTrue(hasattr(co, "generate_stream"))
        self.assertTrue(callable(getattr(co, "generate_stream")))

    @unittest.skip("Skip v2 test for now")
    def test_chat_documents(self) -> None:
        from cohere import Document
        documents = [
            Document(data={"title": "widget sales 2019", "text": "1 million"}),
            Document(data={"title": "widget sales 2020", "text": "2 million"}),
            Document(data={"title": "widget sales 2021", "text": "4 million"}),
        ]

        response = co.chat(
            messages=[cohere.UserChatMessageV2(
                content=[cohere.TextContent(text="how many widges were sold in 2020?")],
            )],
            model="command-a-03-2025",
            documents=documents,
        )

        print(response.message)

    @unittest.skip("Skip v2 test for now")
    def test_chat_tools(self) -> None:
        from typing import Sequence
        get_weather_tool = cohere.ToolV2Function(
            name="get_weather",
            description="gets the weather of a given location",
            parameters={
                "type": "object",
                "properties": {
                    "location": {
                        "type": "str",
                        "description": "the location to get weather, example: San Fransisco, CA",
                    }
                },
                "required": ["location"],
            },
        )
        tools = [cohere.ToolV2(type="function", function=get_weather_tool)]

        messages: cohere.ChatMessages = [
            cohere.UserChatMessageV2(content="what is the weather in Toronto?")
        ]

        res = co.chat(model="command-a-03-2025", tools=tools, messages=messages)

        # call the get_weather tool
        tool_result = {"temperature": "30C"}
        tool_content: Sequence[cohere.TextToolContent] = [cohere.TextToolContent(text="The weather in Toronto is 30C")]

        # Use the first text content from the response if available, else fallback to str
        assistant_content = res.message.content[0].text if (hasattr(res.message, 'content') and isinstance(res.message.content, list) and len(res.message.content) > 0 and hasattr(res.message.content[0], 'text')) else str(res.message)
        messages.append(cohere.AssistantChatMessageV2(content=[cohere.TextAssistantMessageV2ContentItem(text=assistant_content)]))
        if res.message.tool_calls is not None and res.message.tool_calls[0].id is not None:
            messages.append(cohere.ToolChatMessageV2(
                tool_call_id=res.message.tool_calls[0].id,
                content=list(tool_content)))

        res = co.chat(tools=tools, messages=messages, model="command-a-03-2025")

        print(res.message)


================================================
FILE: tests/test_embed_streaming.py
================================================
"""Tests for memory-efficient embed_stream functionality.

All embed_stream code lives in manually maintained files (.fernignore protected):
- src/cohere/client.py — Client.embed_stream()
- src/cohere/manually_maintained/streaming_embed.py — StreamedEmbedding, extraction helpers
"""

import unittest

from cohere.manually_maintained.streaming_embed import (
    StreamedEmbedding,
    extract_embeddings_from_response,
)
from cohere.config import embed_stream_batch_size


class TestStreamedEmbedding(unittest.TestCase):
    """Test the StreamedEmbedding dataclass."""

    def test_creation(self) -> None:
        emb = StreamedEmbedding(index=0, embedding=[0.1, 0.2], embedding_type="float", text="hello")
        self.assertEqual(emb.index, 0)
        self.assertEqual(emb.embedding, [0.1, 0.2])
        self.assertEqual(emb.embedding_type, "float")
        self.assertEqual(emb.text, "hello")

    def test_text_optional(self) -> None:
        # `text` defaults to None when not supplied.
        emb = StreamedEmbedding(index=0, embedding=[0.1], embedding_type="float")
        self.assertIsNone(emb.text)


class TestExtractEmbeddings(unittest.TestCase):
    """Test extract_embeddings_from_response for V1 and V2 formats."""

    def test_v1_embeddings_floats(self) -> None:
        """V1 embeddings_floats response returns flat float embeddings."""
        response = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        }
        results = list(extract_embeddings_from_response(response, ["hello", "world"]))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].index, 0)
        self.assertEqual(results[0].embedding, [0.1, 0.2, 0.3])
        self.assertEqual(results[0].embedding_type, "float")
        self.assertEqual(results[0].text, "hello")
        self.assertEqual(results[1].index, 1)
        self.assertEqual(results[1].text, "world")

    def test_v1_embeddings_by_type(self) -> None:
        """V1 embeddings_by_type response returns typed embeddings."""
        response = {
            "response_type": "embeddings_by_type",
            "embeddings": {
                "float_": [[0.1, 0.2], [0.3, 0.4]],
                "int8": [[1, 2], [3, 4]],
            },
        }
        results = list(extract_embeddings_from_response(response, ["a", "b"]))
        # 2 texts * 2 types = 4 embeddings
        self.assertEqual(len(results), 4)
        float_results = [r for r in results if r.embedding_type == "float"]
        int8_results = [r for r in results if r.embedding_type == "int8"]
        self.assertEqual(len(float_results), 2)
        self.assertEqual(len(int8_results), 2)

    def test_v2_response_format(self) -> None:
        """V2 response (no response_type) returns dict embeddings."""
        response = {
            "embeddings": {
                "float_": [[0.1, 0.2], [0.3, 0.4]],
            },
        }
        results = list(extract_embeddings_from_response(response, ["x", "y"]))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].embedding_type, "float")
        self.assertEqual(results[0].text, "x")

    def test_global_offset(self) -> None:
        """Global offset adjusts indices for batched processing."""
        response = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1], [0.2]],
        }
        results = list(extract_embeddings_from_response(response, ["c", "d"], global_offset=100))
        self.assertEqual(results[0].index, 100)
        self.assertEqual(results[1].index, 101)

    def test_empty_embeddings(self) -> None:
        """Empty response yields nothing."""
        response = {"response_type": "embeddings_floats", "embeddings": []}
        results = list(extract_embeddings_from_response(response, []))
        self.assertEqual(results, [])

    def test_texts_shorter_than_embeddings(self) -> None:
        """Text is None when batch_texts runs out."""
        response = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1], [0.2], [0.3]],
        }
        results = list(extract_embeddings_from_response(response, ["only_one"]))
        self.assertEqual(results[0].text, "only_one")
        self.assertIsNone(results[1].text)
        self.assertIsNone(results[2].text)


class TestBatchSizeConstant(unittest.TestCase):
    """Test that batch_size defaults come from config, not magic numbers."""

    def test_default_batch_size_matches_api_limit(self) -> None:
        self.assertEqual(embed_stream_batch_size, 96)


if __name__ == "__main__":
    unittest.main()


================================================
FILE: tests/test_embed_utils.py
================================================
import unittest

from cohere import EmbeddingsByTypeEmbedResponse, EmbedByTypeResponseEmbeddings, ApiMeta, ApiMetaBilledUnits, \
    ApiMetaApiVersion, EmbeddingsFloatsEmbedResponse
from cohere.utils import merge_embed_responses

# Fixture: full by-type response with every embedding kind populated.
ebt_1 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="1",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[0, 1, 2], [3, 4, 5]],
        int8=[[0, 1, 2], [3, 4, 5]],
        uint8=[[0, 1, 2], [3, 4, 5]],
        binary=[[0, 1, 2], [3, 4, 5]],
        ubinary=[[0, 1, 2], [3, 4, 5]],
    ),
    texts=["hello", "goodbye"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=1, output_tokens=1, search_units=1, classifications=1
        ),
        warnings=["test_warning_1"]
    )
)

ebt_2 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="2",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[7, 8, 9], [10, 11, 12]],
        int8=[[7, 8, 9], [10, 11, 12]],
        uint8=[[7, 8, 9], [10, 11, 12]],
        binary=[[7, 8, 9], [10, 11, 12]],
        ubinary=[[7, 8, 9], [10, 11, 12]],
    ),
    texts=["bye", "seeya"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=2, output_tokens=2, search_units=2, classifications=2
        ),
        warnings=["test_warning_1", "test_warning_2"]
    )
)

# Fixture: by-type response with only a subset of embedding kinds set,
# to exercise merging of partially-populated responses.
ebt_partial_1 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="1",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[0, 1, 2], [3, 4, 5]],
        int8=[[0, 1, 2], [3, 4, 5]],
        binary=[[5, 6, 7], [8, 9, 10]],
    ),
    texts=["hello", "goodbye"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=1, output_tokens=1, search_units=1, classifications=1
        ),
        warnings=["test_warning_1"]
    )
)

ebt_partial_2 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="2",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[7, 8, 9], [10, 11, 12]],
        int8=[[7, 8, 9], [10, 11, 12]],
        binary=[[14, 15, 16], [17, 18, 19]],
    ),
    texts=["bye", "seeya"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=2, output_tokens=2, search_units=2, classifications=2
        ),
        warnings=["test_warning_1", "test_warning_2"]
    )
)

# Fixture: flat float-list responses (v1 "embeddings_floats" shape).
ebf_1 = EmbeddingsFloatsEmbedResponse(
    response_type="embeddings_floats",
    id="1",
    texts=["hello", "goodbye"],
    embeddings=[[0, 1, 2], [3, 4, 5]],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=1, output_tokens=1, search_units=1, classifications=1
        ),
        warnings=["test_warning_1"]
    )
)

ebf_2 = EmbeddingsFloatsEmbedResponse(
    response_type="embeddings_floats",
    id="2",
    texts=["bye", "seeya"],
    embeddings=[[7, 8, 9], [10, 11, 12]],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=2, output_tokens=2, search_units=2, classifications=2
        ),
        warnings=["test_warning_1", "test_warning_2"]
    )
)


class TestClient(unittest.TestCase):
    """Unit tests for merge_embed_responses over the fixtures above."""

    def test_merge_embeddings_by_type(self) -> None:
        resp = merge_embed_responses([
            ebt_1, ebt_2
        ])
        if resp.meta is None:
            raise Exception("this is just for mpy")
        self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
        self.assertEqual(resp, EmbeddingsByTypeEmbedResponse(
            response_type="embeddings_by_type",
            id="1, 2",
            embeddings=EmbedByTypeResponseEmbeddings(
                float_=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                int8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                uint8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                binary=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                ubinary=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
            ),
            texts=["hello", "goodbye", "bye", "seeya"],
            meta=ApiMeta(
                api_version=ApiMetaApiVersion(version="1"),
                billed_units=ApiMetaBilledUnits(
                    input_tokens=3, output_tokens=3, search_units=3, classifications=3
                ),
                warnings=resp.meta.warnings  # order ignored
            )
        ))

    def test_merge_embeddings_floats(self) -> None:
        resp = merge_embed_responses([
            ebf_1, ebf_2
        ])
        if resp.meta is None:
            raise Exception("this is just for mpy")
        self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
        self.assertEqual(resp, EmbeddingsFloatsEmbedResponse(
            response_type="embeddings_floats",
            id="1, 2",
            texts=["hello", "goodbye", "bye", "seeya"],
            embeddings=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
            meta=ApiMeta(
                api_version=ApiMetaApiVersion(version="1"),
                billed_units=ApiMetaBilledUnits(
                    input_tokens=3, output_tokens=3, search_units=3, classifications=3
                ),
                warnings=resp.meta.warnings  # order ignored
            )
        ))

    def test_merge_partial_embeddings_floats(self) -> None:
        resp = merge_embed_responses([
            ebt_partial_1, ebt_partial_2
        ])
        if resp.meta is None:
            raise Exception("this is just for mpy")
        self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
        self.assertEqual(resp, EmbeddingsByTypeEmbedResponse(
            response_type="embeddings_by_type",
            id="1, 2",
            embeddings=EmbedByTypeResponseEmbeddings(
                float_=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                int8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                binary=[[5, 6, 7], [8, 9, 10], [14, 15, 16], [17, 18, 19]],
            ),
            texts=["hello", "goodbye", "bye", "seeya"],
            meta=ApiMeta(
                api_version=ApiMetaApiVersion(version="1"),
                billed_units=ApiMetaBilledUnits(
                    input_tokens=3, output_tokens=3, search_units=3, classifications=3
                ),
                warnings=resp.meta.warnings  # order ignored
            )
        ))


================================================
FILE: tests/test_oci_client.py
================================================
"""Integration and unit tests for OCI Generative AI client.

All integration tests are validated against the live OCI Generative AI
inference layer (us-chicago-1). The OciClientV2 uses the V2 Cohere API format
(COHEREV2) and communicates with the OCI inference endpoint at:
https://inference.generativeai.{region}.oci.oraclecloud.com

Integration test coverage:

V1 API (OciClient — Command R family):
Test                              Model                       What it proves
-------------------------------   --------------------------  ------------------------------------------
test_embed                        embed-english-v3.0          V1 embed returns 2x 1024-dim float vectors
test_chat                         command-r-08-2024           V1 chat returns text with COHERE apiFormat
test_chat_stream                  command-r-08-2024           V1 streaming with text-generation events

V2 API (OciClientV2 — Command A family):
Test                              Model                       What it proves
-------------------------------   --------------------------  ------------------------------------------
test_embed_v2                     embed-english-v3.0          V2 embed returns dict with float_ key
test_embed_with_model_prefix_v2   cohere.embed-english-v3.0   Model normalization works
test_chat_v2                      command-a-03-2025           V2 chat returns message with COHEREV2 format
test_chat_stream_v2               command-a-03-2025           V2 SSE streaming with content-delta events
test_command_a_chat               command-a-03-2025           Command A chat via V2

Cross-cutting:
Test                              Model                       What it proves
-------------------------------   --------------------------  ------------------------------------------
test_config_file_auth             embed-english-v3.0          API key auth from config file
test_custom_profile_auth          embed-english-v3.0          Custom OCI profile auth
test_embed_english_v3             embed-english-v3.0          1024-dim embeddings
test_embed_multilingual_v3        embed-multilingual-v3.0     Multilingual model works
test_invalid_model                invalid-model-name          Error handling works
test_missing_compartment_id       --                          Raises TypeError

Requirements:
1. OCI SDK installed: pip install oci
2. OCI credentials configured in ~/.oci/config
3. TEST_OCI environment variable set to run
4. OCI_COMPARTMENT_ID environment variable with valid OCI compartment OCID
5. OCI_REGION environment variable (optional, defaults to us-chicago-1)

Run with:
    TEST_OCI=1 OCI_COMPARTMENT_ID=ocid1.compartment.oc1... pytest tests/test_oci_client.py
"""

import os
import sys
import tempfile
import types
import unittest
from unittest.mock import MagicMock, mock_open, patch

import cohere

# Stub heavy optional dependencies so this module imports without them.
if "tokenizers" not in sys.modules:
    tokenizers_stub = types.ModuleType("tokenizers")
    tokenizers_stub.Tokenizer = object  # type: ignore[attr-defined]
    sys.modules["tokenizers"] = tokenizers_stub

if "fastavro" not in sys.modules:
    fastavro_stub = types.ModuleType("fastavro")
    fastavro_stub.parse_schema = lambda schema: schema  # type: ignore[attr-defined]
    fastavro_stub.reader = lambda *args, **kwargs: iter(())  # type: ignore[attr-defined]
    fastavro_stub.writer = lambda *args, **kwargs: None  # type: ignore[attr-defined]
    sys.modules["fastavro"] = fastavro_stub

if "httpx_sse" not in sys.modules:
    httpx_sse_stub = types.ModuleType("httpx_sse")
    httpx_sse_stub.connect_sse = lambda *args, **kwargs: None  # type: ignore[attr-defined]
    sys.modules["httpx_sse"] = httpx_sse_stub


@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
class TestOciClient(unittest.TestCase):
    """Test OciClient (V1 API) with OCI Generative AI."""

    def setUp(self) -> None:
        compartment_id = os.getenv("OCI_COMPARTMENT_ID")
        if not compartment_id:
            self.skipTest("OCI_COMPARTMENT_ID not set")
        region = os.getenv("OCI_REGION", "us-chicago-1")
        profile = os.getenv("OCI_PROFILE", "DEFAULT")
        self.client = cohere.OciClient(
            oci_region=region,
            oci_compartment_id=compartment_id,
            oci_profile=profile,
        )

    def test_embed(self) -> None:
        """Test embedding with V1 client."""
        response = self.client.embed(
            model="embed-english-v3.0",
            texts=["Hello world", "Cohere on OCI"],
            input_type="search_document",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        self.assertEqual(len(response.embeddings), 2)
        self.assertEqual(len(response.embeddings[0]), 1024)
        self.assertEqual(response.response_type, "embeddings_floats")

    def test_chat(self) -> None:
        """Test V1 chat with Command R."""
        response = self.client.chat(
            model="command-r-08-2024",
            message="What is 2+2? Answer with just the number.",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.text)
        self.assertIn("4", response.text)

    def test_chat_stream(self) -> None:
        """Test V1 streaming chat terminates and produces correct events."""
        events = []
        for event in self.client.chat_stream(
            model="command-r-08-2024",
            message="Count from 1 to 3.",
        ):
            events.append(event)
        self.assertTrue(len(events) > 0)
        text_events = [e for e in events if hasattr(e, "text") and e.text]
        self.assertTrue(len(text_events) > 0)
        # Verify stream terminates with correct event lifecycle
        event_types = [getattr(e, "event_type", None) for e in events]
        self.assertEqual(event_types[0], "stream-start")
        self.assertEqual(event_types[-1], "stream-end")


@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
class TestOciClientV2(unittest.TestCase):
    """Test OciClientV2 (v2 API) with OCI Generative AI."""

    def setUp(self) -> None:
        """Set up OCI v2 client for each test."""
        compartment_id = os.getenv("OCI_COMPARTMENT_ID")
        if not compartment_id:
            self.skipTest("OCI_COMPARTMENT_ID not set")
        region = os.getenv("OCI_REGION", "us-chicago-1")
        profile = os.getenv("OCI_PROFILE", "DEFAULT")
        self.client = cohere.OciClientV2(
            oci_region=region,
            oci_compartment_id=compartment_id,
            oci_profile=profile,
        )

    def test_embed_v2(self) -> None:
        """Test embedding with v2 client."""
        response = self.client.embed(
            model="embed-english-v3.0",
            texts=["Hello from v2", "Second text"],
            input_type="search_document",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        # V2 returns embeddings as a dict with "float" key
        self.assertIsNotNone(response.embeddings.float_)
        self.assertEqual(len(response.embeddings.float_), 2)
        # Verify embedding dimensions (1024 for embed-english-v3.0)
        self.assertEqual(len(response.embeddings.float_[0]), 1024)
        self.assertEqual(response.response_type, "embeddings_by_type")

    def test_embed_with_model_prefix_v2(self) -> None:
        """Test embedding with 'cohere.' model prefix on v2 client."""
        response = self.client.embed(
            model="cohere.embed-english-v3.0",
            texts=["Test with prefix"],
            input_type="search_document",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        self.assertIsNotNone(response.embeddings.float_)
        self.assertEqual(len(response.embeddings.float_), 1)

    def test_chat_v2(self) -> None:
        """Test chat with v2 client."""
        response = self.client.chat(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "Say hello"}],
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.message)

    # NOTE(review): this chunk ends mid-method below — the remainder of
    # test_chat_vision_v2 continues past the visible source boundary.
    def test_chat_vision_v2(self) -> None:
        """Test vision with inline image on Command A Vision."""
        import base64, struct, zlib
        # Create a minimal 1x1 red PNG
        raw = b'\x00\xff\x00\x00'
        compressed = zlib.compress(raw)

        def chunk(ctype, data):
            c = ctype + data
            return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)

        ihdr = struct.pack('>IIBBBBB', 1, 1, 8, 2, 0, 0, 0)
        png = b'\x89PNG\r\n\x1a\n' + chunk(b'IHDR', ihdr) + chunk(b'IDAT', compressed) + chunk(b'IEND', b'')
        img_b64 = base64.b64encode(png).decode()

        response = self.client.chat(
            model="command-a-vision",
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is this image? 
Reply with one word."}, {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, ], }], ) self.assertIsNotNone(response) self.assertIsNotNone(response.message) self.assertTrue(len(response.message.content) > 0) # The 1x1 red pixel should be identified as red self.assertIn("red", response.message.content[0].text.lower()) def test_chat_tool_use_v2(self): """Test tool use with v2 client on OCI on-demand inference.""" response = self.client.chat( model="command-a-03-2025", messages=[{"role": "user", "content": "What's the weather in Toronto?"}], max_tokens=200, tools=[{ "type": "function", "function": { "name": "get_weather", "description": "Get current weather for a location", "parameters": { "type": "object", "properties": { "location": {"type": "string", "description": "City name"} }, "required": ["location"], }, }, }], ) self.assertIsNotNone(response) self.assertIsNotNone(response.message) self.assertEqual(response.finish_reason, "TOOL_CALL") self.assertTrue(len(response.message.tool_calls) > 0) tool_call = response.message.tool_calls[0] self.assertEqual(tool_call.function.name, "get_weather") self.assertIn("Toronto", tool_call.function.arguments) def test_chat_tool_use_response_type_lowered(self): """Test that tool_call type is lowercased in response (OCI returns FUNCTION).""" response = self.client.chat( model="command-a-03-2025", messages=[{"role": "user", "content": "What's the weather in Toronto?"}], max_tokens=200, tools=[{ "type": "function", "function": { "name": "get_weather", "description": "Get current weather for a location", "parameters": { "type": "object", "properties": { "location": {"type": "string", "description": "City name"} }, "required": ["location"], }, }, }], ) self.assertEqual(response.finish_reason, "TOOL_CALL") tool_call = response.message.tool_calls[0] # OCI returns "FUNCTION" — SDK must lowercase to "function" for Cohere compat self.assertEqual(tool_call.type, "function") def 
test_chat_multi_turn_tool_use_v2(self): """Test multi-turn tool use: send tool result back after tool call.""" # Step 1: Get a tool call response = self.client.chat( model="command-a-03-2025", messages=[{"role": "user", "content": "What's the weather in Toronto?"}], max_tokens=200, tools=[{ "type": "function", "function": { "name": "get_weather", "description": "Get current weather for a location", "parameters": { "type": "object", "properties": { "location": {"type": "string", "description": "City name"} }, "required": ["location"], }, }, }], ) self.assertEqual(response.finish_reason, "TOOL_CALL") tool_call = response.message.tool_calls[0] # Step 2: Send tool result back response2 = self.client.chat( model="command-a-03-2025", messages=[ {"role": "user", "content": "What's the weather in Toronto?"}, { "role": "assistant", "tool_calls": [{"id": tool_call.id, "type": "function", "function": {"name": "get_weather", "arguments": tool_call.function.arguments}}], "tool_plan": response.message.tool_plan, }, { "role": "tool", "tool_call_id": tool_call.id, "content": [{"type": "text", "text": "15°C, sunny"}], }, ], max_tokens=200, tools=[{ "type": "function", "function": { "name": "get_weather", "description": "Get current weather for a location", "parameters": { "type": "object", "properties": { "location": {"type": "string", "description": "City name"} }, "required": ["location"], }, }, }], ) self.assertIsNotNone(response2.message) # Model should respond with text incorporating the tool result self.assertTrue(len(response2.message.content) > 0) def test_chat_safety_mode_v2(self): """Test that safety_mode is uppercased for OCI.""" # Cohere SDK enum values are already uppercase, but test lowercase too response = self.client.chat( model="command-a-03-2025", messages=[{"role": "user", "content": "Say hi"}], safety_mode="STRICT", ) self.assertIsNotNone(response.message) def test_chat_stream_v2(self): """Test V2 streaming chat terminates and produces correct event 
lifecycle.""" events = [] for event in self.client.chat_stream( model="command-a-03-2025", messages=[{"role": "user", "content": "Count from 1 to 3"}], ): events.append(event) self.assertTrue(len(events) > 0) # Verify full event lifecycle: message-start → content-start → content-delta(s) → content-end → message-end event_types = [e.type for e in events] self.assertEqual(event_types[0], "message-start") self.assertIn("content-start", event_types) self.assertIn("content-delta", event_types) self.assertIn("content-end", event_types) self.assertEqual(event_types[-1], "message-end") # Verify we can extract text from content-delta events full_text = "" for event in events: if ( hasattr(event, "delta") and event.delta and hasattr(event.delta, "message") and event.delta.message and hasattr(event.delta.message, "content") and event.delta.message.content and hasattr(event.delta.message.content, "text") and event.delta.message.content.text is not None ): full_text += event.delta.message.content.text self.assertTrue(len(full_text) > 0) @unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") class TestOciClientAuthentication(unittest.TestCase): """Test different OCI authentication methods.""" def test_config_file_auth(self): """Test authentication using OCI config file.""" compartment_id = os.getenv("OCI_COMPARTMENT_ID") if not compartment_id: self.skipTest("OCI_COMPARTMENT_ID not set") profile = os.getenv("OCI_PROFILE", "DEFAULT") client = cohere.OciClientV2( oci_region="us-chicago-1", oci_compartment_id=compartment_id, oci_profile=profile, ) # Test with a simple embed call response = client.embed( model="embed-english-v3.0", texts=["Auth test"], input_type="search_document", ) self.assertIsNotNone(response) self.assertIsNotNone(response.embeddings) def test_custom_profile_auth(self): """Test authentication using custom OCI profile.""" compartment_id = os.getenv("OCI_COMPARTMENT_ID") profile = os.getenv("OCI_PROFILE", "DEFAULT") if not compartment_id: 
self.skipTest("OCI_COMPARTMENT_ID not set") client = cohere.OciClientV2( oci_profile=profile, oci_region="us-chicago-1", oci_compartment_id=compartment_id, ) response = client.embed( model="embed-english-v3.0", texts=["Profile auth test"], input_type="search_document", ) self.assertIsNotNone(response) @unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") class TestOciClientErrors(unittest.TestCase): """Test error handling in OCI client.""" def test_missing_compartment_id(self): """Test error when compartment ID is missing.""" with self.assertRaises(TypeError): cohere.OciClientV2( oci_region="us-chicago-1", # Missing oci_compartment_id ) def test_invalid_model(self): """Test error handling with invalid model.""" compartment_id = os.getenv("OCI_COMPARTMENT_ID") if not compartment_id: self.skipTest("OCI_COMPARTMENT_ID not set") profile = os.getenv("OCI_PROFILE", "DEFAULT") client = cohere.OciClientV2( oci_region="us-chicago-1", oci_compartment_id=compartment_id, oci_profile=profile, ) # OCI should return an error for invalid model with self.assertRaises(Exception): client.embed( model="invalid-model-name", texts=["Test"], input_type="search_document", ) @unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") class TestOciClientModels(unittest.TestCase): """Test different Cohere models on OCI.""" def setUp(self): """Set up OCI client for each test.""" compartment_id = os.getenv("OCI_COMPARTMENT_ID") if not compartment_id: self.skipTest("OCI_COMPARTMENT_ID not set") region = os.getenv("OCI_REGION", "us-chicago-1") profile = os.getenv("OCI_PROFILE", "DEFAULT") self.client = cohere.OciClientV2( oci_region=region, oci_compartment_id=compartment_id, oci_profile=profile, ) def test_embed_english_v3(self): """Test embed-english-v3.0 model.""" response = self.client.embed( model="embed-english-v3.0", texts=["Test"], input_type="search_document", ) self.assertIsNotNone(response.embeddings) self.assertIsNotNone(response.embeddings.float_) 
self.assertEqual(len(response.embeddings.float_[0]), 1024) def test_embed_multilingual_v3(self): """Test embed-multilingual-v3.0 model.""" response = self.client.embed( model="embed-multilingual-v3.0", texts=["Test"], input_type="search_document", ) self.assertIsNotNone(response.embeddings) self.assertIsNotNone(response.embeddings.float_) self.assertEqual(len(response.embeddings.float_[0]), 1024) def test_command_a_chat(self): """Test command-a-03-2025 model for chat.""" response = self.client.chat( model="command-a-03-2025", messages=[{"role": "user", "content": "Hello"}], ) self.assertIsNotNone(response.message) def test_embed_english_light_v3(self): """Test embed-english-light-v3.0 returns 384-dim vectors.""" response = self.client.embed( model="embed-english-light-v3.0", texts=["Hello world"], input_type="search_document", ) self.assertIsNotNone(response.embeddings.float_) self.assertEqual(len(response.embeddings.float_[0]), 384) def test_embed_multilingual_light_v3(self): """Test embed-multilingual-light-v3.0 returns 384-dim vectors.""" response = self.client.embed( model="embed-multilingual-light-v3.0", texts=["Bonjour le monde"], input_type="search_document", ) self.assertIsNotNone(response.embeddings.float_) self.assertEqual(len(response.embeddings.float_[0]), 384) def test_embed_search_query_input_type(self): """Test embed with search_query input_type (distinct from search_document).""" response = self.client.embed( model="embed-english-v3.0", texts=["What is the capital of France?"], input_type="search_query", ) self.assertIsNotNone(response.embeddings.float_) self.assertEqual(len(response.embeddings.float_[0]), 1024) def test_embed_with_embedding_types(self): """Test embed with explicit embedding_types parameter.""" response = self.client.embed( model="embed-english-v3.0", texts=["Hello world"], input_type="search_document", embedding_types=["float"], ) self.assertIsNotNone(response.embeddings.float_) self.assertEqual(len(response.embeddings.float_[0]), 
1024) def test_embed_with_truncate(self): """Test embed with truncate parameter.""" long_text = "hello " * 1000 for mode in ["NONE", "START", "END"]: response = self.client.embed( model="embed-english-v3.0", texts=[long_text], input_type="search_document", truncate=mode, ) self.assertIsNotNone(response.embeddings.float_) self.assertEqual(len(response.embeddings.float_[0]), 1024) def test_command_r_plus_chat(self): """Test command-r-plus-08-2024 via V1 client.""" v1_client = cohere.OciClient( oci_region=os.getenv("OCI_REGION", "us-chicago-1"), oci_compartment_id=os.getenv("OCI_COMPARTMENT_ID"), oci_profile=os.getenv("OCI_PROFILE", "DEFAULT"), ) response = v1_client.chat( model="command-r-plus-08-2024", message="What is 2+2? Answer with just the number.", ) self.assertIsNotNone(response.text) self.assertIn("4", response.text) def test_v2_multi_turn_chat(self): """Test V2 chat with conversation history (multi-turn).""" response = self.client.chat( model="command-a-03-2025", messages=[ {"role": "user", "content": "My name is Alice."}, {"role": "assistant", "content": "Nice to meet you, Alice!"}, {"role": "user", "content": "What is my name?"}, ], ) self.assertIsNotNone(response.message) content = response.message.content[0].text self.assertIn("Alice", content) def test_v2_system_message(self): """Test V2 chat with a system message.""" response = self.client.chat( model="command-a-03-2025", messages=[ {"role": "system", "content": "You are a helpful assistant. 
Always respond in exactly 3 words."}, {"role": "user", "content": "Say hello."}, ], ) self.assertIsNotNone(response.message) self.assertIsNotNone(response.message.content[0].text) class TestOciClientTransformations(unittest.TestCase): """Unit tests for OCI request/response transformations (no OCI credentials required).""" def test_thinking_parameter_transformation(self): """Test that thinking parameter is correctly transformed to OCI format.""" from cohere.oci_client import transform_request_to_oci cohere_body = { "model": "command-a-reasoning-08-2025", "messages": [{"role": "user", "content": "What is 2+2?"}], "thinking": { "type": "enabled", "token_budget": 10000, }, } result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True) # Verify thinking parameter is transformed with camelCase for OCI API chat_request = result["chatRequest"] self.assertIn("thinking", chat_request) self.assertEqual(chat_request["thinking"]["type"], "ENABLED") self.assertEqual(chat_request["thinking"]["tokenBudget"], 10000) # camelCase for OCI def test_thinking_parameter_disabled(self): """Test that disabled thinking is correctly transformed.""" from cohere.oci_client import transform_request_to_oci cohere_body = { "model": "command-a-reasoning-08-2025", "messages": [{"role": "user", "content": "Hello"}], "thinking": { "type": "disabled", }, } result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True) chat_request = result["chatRequest"] self.assertIn("thinking", chat_request) self.assertEqual(chat_request["thinking"]["type"], "DISABLED") self.assertNotIn("token_budget", chat_request["thinking"]) def test_thinking_response_transformation(self): """Test that thinking content in response is correctly transformed.""" from cohere.oci_client import transform_oci_response_to_cohere oci_response = { "chatResponse": { "id": "test-id", "message": { "role": "ASSISTANT", "content": [ {"type": "THINKING", "thinking": "Let me think about this..."}, 
{"type": "TEXT", "text": "The answer is 4."}, ], }, "finishReason": "COMPLETE", "usage": {"inputTokens": 10, "completionTokens": 20}, } } result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) # Verify content types are lowercased self.assertEqual(result["message"]["content"][0]["type"], "thinking") self.assertEqual(result["message"]["content"][1]["type"], "text") def test_stream_event_thinking_transformation(self): """Test that thinking content in stream events is correctly transformed.""" from cohere.oci_client import transform_stream_event # OCI thinking event oci_event = { "message": { "content": [{"type": "THINKING", "thinking": "Reasoning step..."}] } } result = transform_stream_event("chat", oci_event, is_v2=True) self.assertEqual(result[0]["type"], "content-delta") self.assertIn("thinking", result[0]["delta"]["message"]["content"]) self.assertEqual(result[0]["delta"]["message"]["content"]["thinking"], "Reasoning step...") def test_stream_event_text_transformation(self): """Test that text content in stream events is correctly transformed.""" from cohere.oci_client import transform_stream_event # OCI text event oci_event = { "message": { "content": [{"type": "TEXT", "text": "The answer is..."}] } } result = transform_stream_event("chat", oci_event, is_v2=True) self.assertEqual(result[0]["type"], "content-delta") self.assertIn("text", result[0]["delta"]["message"]["content"]) self.assertEqual(result[0]["delta"]["message"]["content"]["text"], "The answer is...") def test_thinking_parameter_none(self): """Test that thinking=None does not crash (issue: null guard).""" from cohere.oci_client import transform_request_to_oci cohere_body = { "model": "command-a-03-2025", "messages": [{"role": "user", "content": "Hello"}], "thinking": None, # Explicitly set to None } # Should not crash with TypeError result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True) chat_request = result["chatRequest"] # thinking should not be 
in request when None self.assertNotIn("thinking", chat_request) def test_v2_response_role_lowercased(self): """Test that V2 response message role is lowercased.""" from cohere.oci_client import transform_oci_response_to_cohere oci_response = { "chatResponse": { "id": "test-id", "message": { "role": "ASSISTANT", "content": [{"type": "TEXT", "text": "Hello"}], }, "finishReason": "COMPLETE", "usage": {"inputTokens": 10, "completionTokens": 20}, } } result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) # Role should be lowercased self.assertEqual(result["message"]["role"], "assistant") def test_v2_response_finish_reason_uppercase(self): """Test that V2 response finish_reason stays uppercase.""" from cohere.oci_client import transform_oci_response_to_cohere oci_response = { "chatResponse": { "id": "test-id", "message": { "role": "ASSISTANT", "content": [{"type": "TEXT", "text": "Hello"}], }, "finishReason": "MAX_TOKENS", "usage": {"inputTokens": 10, "completionTokens": 20}, } } result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) # V2 finish_reason should stay uppercase self.assertEqual(result["finish_reason"], "MAX_TOKENS") def test_v2_response_tool_calls_conversion(self): """Test that V2 response converts toolCalls to tool_calls.""" from cohere.oci_client import transform_oci_response_to_cohere oci_response = { "chatResponse": { "id": "test-id", "message": { "role": "ASSISTANT", "content": [{"type": "TEXT", "text": "I'll help with that."}], "toolCalls": [ { "id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": '{"city": "London"}'}, } ], }, "finishReason": "TOOL_CALL", "usage": {"inputTokens": 10, "completionTokens": 20}, } } result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) # toolCalls should be converted to tool_calls self.assertIn("tool_calls", result["message"]) self.assertNotIn("toolCalls", result["message"]) 
self.assertEqual(len(result["message"]["tool_calls"]), 1) self.assertEqual(result["message"]["tool_calls"][0]["id"], "call_123") def test_normalize_model_for_oci(self): """Test model name normalization for OCI.""" from cohere.oci_client import normalize_model_for_oci # Plain model name gets cohere. prefix self.assertEqual(normalize_model_for_oci("command-a-03-2025"), "cohere.command-a-03-2025") # Already prefixed passes through self.assertEqual(normalize_model_for_oci("cohere.embed-english-v3.0"), "cohere.embed-english-v3.0") # OCID passes through self.assertEqual( normalize_model_for_oci("ocid1.generativeaimodel.oc1.us-chicago-1.abc"), "ocid1.generativeaimodel.oc1.us-chicago-1.abc", ) def test_transform_embed_request(self): """Test embed request transformation to OCI format.""" from cohere.oci_client import transform_request_to_oci body = { "model": "embed-english-v3.0", "texts": ["hello", "world"], "input_type": "search_document", "truncate": "end", "embedding_types": ["float", "int8"], } result = transform_request_to_oci("embed", body, "compartment-123") self.assertEqual(result["inputs"], ["hello", "world"]) self.assertEqual(result["inputType"], "SEARCH_DOCUMENT") self.assertEqual(result["truncate"], "END") self.assertEqual(result["embeddingTypes"], ["float", "int8"]) self.assertEqual(result["compartmentId"], "compartment-123") self.assertEqual(result["servingMode"]["modelId"], "cohere.embed-english-v3.0") def test_transform_embed_request_with_optional_params(self): """Test embed request forwards optional params.""" from cohere.oci_client import transform_request_to_oci body = { "model": "embed-english-v3.0", "inputs": [{"content": [{"type": "text", "text": "hello"}]}], "input_type": "classification", "max_tokens": 256, "output_dimension": 512, "priority": 42, } result = transform_request_to_oci("embed", body, "compartment-123") self.assertEqual(result["inputs"], body["inputs"]) self.assertEqual(result["maxTokens"], 256) 
self.assertEqual(result["outputDimension"], 512) self.assertEqual(result["priority"], 42) def test_transform_embed_request_rejects_images(self): """Test embed request fails clearly for unsupported top-level images.""" from cohere.oci_client import transform_request_to_oci with self.assertRaises(ValueError) as ctx: transform_request_to_oci( "embed", { "model": "embed-english-v3.0", "images": ["data:image/png;base64,abc"], "input_type": "classification", }, "compartment-123", ) self.assertIn("top-level 'images' parameter", str(ctx.exception)) def test_transform_chat_request_optional_params(self): """Test chat request transformation includes optional params.""" from cohere.oci_client import transform_request_to_oci body = { "model": "command-a-03-2025", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 100, "temperature": 0.7, "stop_sequences": ["END"], "frequency_penalty": 0.5, "strict_tools": True, "response_format": {"type": "json_object"}, "logprobs": True, "tool_choice": "REQUIRED", "priority": 7, } result = transform_request_to_oci("chat", body, "compartment-123", is_v2=True) chat_req = result["chatRequest"] self.assertEqual(chat_req["maxTokens"], 100) self.assertEqual(chat_req["temperature"], 0.7) self.assertEqual(chat_req["stopSequences"], ["END"]) self.assertEqual(chat_req["frequencyPenalty"], 0.5) self.assertTrue(chat_req["strictTools"]) self.assertEqual(chat_req["responseFormat"], {"type": "json_object"}) self.assertTrue(chat_req["logprobs"]) self.assertEqual(chat_req["toolChoice"], "REQUIRED") self.assertEqual(chat_req["priority"], 7) def test_v2_client_rejects_v1_request(self): """Test OciClientV2 fails when given V1-style 'message' string.""" from cohere.oci_client import transform_request_to_oci with self.assertRaises(ValueError) as ctx: transform_request_to_oci( "chat", {"model": "command-a-03-2025", "message": "Hello"}, "compartment-123", is_v2=True, ) self.assertIn("OciClientV2", str(ctx.exception)) def 
test_v1_client_rejects_v2_request(self): """Test OciClient fails when given V2-style 'messages' array.""" from cohere.oci_client import transform_request_to_oci with self.assertRaises(ValueError) as ctx: transform_request_to_oci( "chat", {"model": "command-r-08-2024", "messages": [{"role": "user", "content": "Hi"}]}, "compartment-123", is_v2=False, ) self.assertIn("OciClient ", str(ctx.exception)) def test_unsupported_endpoint_raises(self): """Test that transform_request_to_oci raises for unsupported endpoints.""" from cohere.oci_client import transform_request_to_oci with self.assertRaises(ValueError) as ctx: transform_request_to_oci("rerank", {"model": "rerank-v3.5"}, "compartment-123") self.assertIn("rerank", str(ctx.exception)) self.assertIn("not supported", str(ctx.exception)) def test_v1_chat_request_optional_params(self): """Test V1 chat request forwards supported optional params.""" from cohere.oci_client import transform_request_to_oci body = { "model": "command-r-08-2024", "message": "Hi", "max_tokens": 100, "temperature": 0.7, "k": 10, "p": 0.8, "seed": 123, "stop_sequences": ["END"], "frequency_penalty": 0.5, "presence_penalty": 0.2, "documents": [{"title": "Doc", "text": "Body"}], "tools": [{"name": "lookup"}], "tool_results": [{"call": {"name": "lookup"}}], "response_format": {"type": "json_object"}, "safety_mode": "NONE", "priority": 4, } result = transform_request_to_oci("chat", body, "compartment-123", is_v2=False) chat_req = result["chatRequest"] self.assertEqual(chat_req["apiFormat"], "COHERE") self.assertEqual(chat_req["message"], "Hi") self.assertEqual(chat_req["maxTokens"], 100) self.assertEqual(chat_req["temperature"], 0.7) self.assertEqual(chat_req["topK"], 10) self.assertEqual(chat_req["topP"], 0.8) self.assertEqual(chat_req["seed"], 123) self.assertEqual(chat_req["frequencyPenalty"], 0.5) self.assertEqual(chat_req["presencePenalty"], 0.2) self.assertEqual(chat_req["priority"], 4) def test_v1_stream_wrapper_preserves_finish_reason(self): 
"""Test V1 stream-end uses the OCI finish reason from the final event.""" import json from cohere.oci_client import transform_oci_stream_wrapper chunks = [ b'data: {"text": "Hello", "isFinished": false}\n', b'data: {"text": " world", "isFinished": true, "finishReason": "MAX_TOKENS"}\n', b"data: [DONE]\n", ] events = [ json.loads(raw.decode("utf-8")) for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=False) ] # First event should be stream-start with generation_id self.assertEqual(events[0]["event_type"], "stream-start") self.assertIn("generation_id", events[0]) self.assertEqual(events[3]["event_type"], "stream-end") self.assertEqual(events[3]["finish_reason"], "MAX_TOKENS") self.assertEqual(events[3]["response"]["text"], "Hello world") def test_transform_chat_request_tool_message_fields(self): """Test tool message fields are converted to OCI names.""" from cohere.oci_client import transform_request_to_oci body = { "model": "command-a-03-2025", "messages": [ { "role": "assistant", "content": [{"type": "text", "text": "Use tool"}], "tool_calls": [{"id": "call_1"}], "tool_plan": "Plan", }, { "role": "tool", "tool_call_id": "call_1", "content": [{"type": "text", "text": "Result"}], }, ], } result = transform_request_to_oci("chat", body, "compartment-123", is_v2=True) assistant_message, tool_message = result["chatRequest"]["messages"] self.assertEqual(assistant_message["toolCalls"], [{"id": "call_1"}]) self.assertEqual(assistant_message["toolPlan"], "Plan") self.assertEqual(tool_message["toolCallId"], "call_1") def test_get_oci_url_known_endpoints(self): """Test URL generation for known endpoints.""" from cohere.oci_client import get_oci_url url = get_oci_url("us-chicago-1", "embed") self.assertIn("/actions/embedText", url) url = get_oci_url("us-chicago-1", "chat") self.assertIn("/actions/chat", url) def test_get_oci_url_unknown_endpoint_raises(self): """Test that unknown endpoints raise ValueError instead of producing bad URLs.""" from 
cohere.oci_client import get_oci_url with self.assertRaises(ValueError) as ctx: get_oci_url("us-chicago-1", "unknown_endpoint") self.assertIn("not supported", str(ctx.exception)) def test_load_oci_config_missing_private_key_raises(self): """Test that direct credentials without private key raises clear error.""" from cohere.oci_client import _load_oci_config with patch("cohere.oci_client.lazy_oci", return_value=MagicMock()): with self.assertRaises(ValueError) as ctx: _load_oci_config( auth_type="api_key", config_path=None, profile=None, user_id="ocid1.user.oc1...", fingerprint="xx:xx:xx", tenancy_id="ocid1.tenancy.oc1...", # No private_key_path or private_key_content ) self.assertIn("oci_private_key_path", str(ctx.exception)) def test_load_oci_config_ignores_inherited_session_auth(self): """Test that named API-key profiles do not inherit DEFAULT session auth fields.""" from cohere.oci_client import _load_oci_config config_text = """ [DEFAULT] security_token_file=/tmp/default-token [API_KEY_AUTH] user=ocid1.user.oc1..test fingerprint=aa:bb key_file=/tmp/test.pem tenancy=ocid1.tenancy.oc1..test region=us-chicago-1 """.strip() with tempfile.NamedTemporaryFile("w", delete=False) as config_file: config_file.write(config_text) config_path = config_file.name try: mock_oci = MagicMock() mock_oci.config.from_file.return_value = { "user": "ocid1.user.oc1..test", "fingerprint": "aa:bb", "key_file": "/tmp/test.pem", "tenancy": "ocid1.tenancy.oc1..test", "region": "us-chicago-1", "security_token_file": "/tmp/default-token", } with patch("cohere.oci_client.lazy_oci", return_value=mock_oci): config = _load_oci_config( auth_type="api_key", config_path=config_path, profile="API_KEY_AUTH", ) finally: os.unlink(config_path) self.assertNotIn("security_token_file", config) def test_session_auth_prefers_security_token_signer(self): """Test session-based auth uses SecurityTokenSigner before API key signer.""" from cohere.oci_client import map_request_to_oci mock_oci = MagicMock() 
mock_security_signer = MagicMock() mock_oci.signer.load_private_key_from_file.return_value = "private-key" mock_oci.auth.signers.SecurityTokenSigner.return_value = mock_security_signer with patch("cohere.oci_client.lazy_oci", return_value=mock_oci), patch( "builtins.open", mock_open(read_data="session-token") ): hook = map_request_to_oci( oci_config={ "user": "ocid1.user.oc1..example", "fingerprint": "xx:xx", "tenancy": "ocid1.tenancy.oc1..example", "security_token_file": "~/.oci/token", "key_file": "~/.oci/key.pem", }, oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1..example", ) request = MagicMock() request.url.path = "/v2/embed" request.read.return_value = b'{"model":"embed-english-v3.0","texts":["hello"]}' request.method = "POST" request.extensions = {} hook(request) # SecurityTokenSigner is called at least once (init) and again per request # (token file is re-read on each signing call to pick up refreshed tokens). mock_oci.auth.signers.SecurityTokenSigner.assert_called_with( token="session-token", private_key="private-key", ) self.assertGreaterEqual(mock_oci.auth.signers.SecurityTokenSigner.call_count, 1) mock_oci.signer.Signer.assert_not_called() def test_session_token_refreshed_on_subsequent_requests(self): """Verify the refreshing signer picks up a new token written to the token file.""" import tempfile import os from cohere.oci_client import map_request_to_oci mock_oci = MagicMock() mock_oci.signer.load_private_key_from_file.return_value = "private-key" # Write initial token to a real temp file so we can overwrite it later. 
with tempfile.NamedTemporaryFile("w", suffix=".token", delete=False) as tf: tf.write("token-v1") token_path = tf.name try: with patch("cohere.oci_client.lazy_oci", return_value=mock_oci): hook = map_request_to_oci( oci_config={ "security_token_file": token_path, "key_file": "/irrelevant.pem", }, oci_region="us-chicago-1", oci_compartment_id="ocid1.compartment.oc1..example", ) def _make_request(): req = MagicMock() req.url.path = "/v2/embed" req.read.return_value = b'{"model":"embed-english-v3.0","texts":["hi"]}' req.method = "POST" req.extensions = {} return req # First request uses token-v1 hook(_make_request()) calls_after_first = mock_oci.auth.signers.SecurityTokenSigner.call_count # Simulate token refresh by overwriting the file with open(token_path, "w") as _f: _f.write("token-v2") # Second request — should re-read and use token-v2 hook(_make_request()) self.assertGreater( mock_oci.auth.signers.SecurityTokenSigner.call_count, calls_after_first, "SecurityTokenSigner should be re-instantiated after token file update", ) # Verify the latest call used the refreshed token all_calls = mock_oci.auth.signers.SecurityTokenSigner.call_args_list last_call = all_calls[-1] last_token = last_call.kwargs.get("token") or (last_call.args[0] if last_call.args else None) self.assertEqual(last_token, "token-v2", "Last signing call must use the refreshed token") finally: os.unlink(token_path) def test_embed_response_lowercases_embedding_keys(self): """Test embed response uses lowercase keys expected by the SDK model.""" from cohere.oci_client import transform_oci_response_to_cohere result = transform_oci_response_to_cohere( "embed", { "id": "embed-id", "embeddings": {"FLOAT": [[0.1, 0.2]], "INT8": [[1, 2]]}, "usage": {"inputTokens": 3, "completionTokens": 7}, }, is_v2=True, ) self.assertIn("float", result["embeddings"]) self.assertIn("int8", result["embeddings"]) self.assertNotIn("FLOAT", result["embeddings"]) self.assertEqual(result["meta"]["tokens"]["output_tokens"], 7) def 
test_embed_response_includes_response_type_v1(self): """Test V1 embed response includes response_type=embeddings_floats for SDK union.""" from cohere.oci_client import transform_oci_response_to_cohere result = transform_oci_response_to_cohere( "embed", { "id": "embed-id", "embeddings": [[0.1, 0.2]], "usage": {"inputTokens": 3, "completionTokens": 0}, }, is_v2=False, ) self.assertEqual(result["response_type"], "embeddings_floats") def test_embed_response_includes_response_type_v2(self): """Test V2 embed response includes response_type=embeddings_by_type for SDK union.""" from cohere.oci_client import transform_oci_response_to_cohere result = transform_oci_response_to_cohere( "embed", { "id": "embed-id", "embeddings": {"FLOAT": [[0.1, 0.2]]}, "usage": {"inputTokens": 3, "completionTokens": 0}, }, is_v2=True, ) self.assertEqual(result["response_type"], "embeddings_by_type") def test_normalize_model_for_oci_rejects_empty_model(self): """Test model normalization fails clearly for empty model names.""" from cohere.oci_client import normalize_model_for_oci with self.assertRaises(ValueError) as ctx: normalize_model_for_oci("") self.assertIn("non-empty model", str(ctx.exception)) def test_stream_wrapper_emits_full_event_lifecycle(self): """Test that stream emits message-start, content-start, content-delta, content-end, message-end.""" import json from cohere.oci_client import transform_oci_stream_wrapper chunks = [ b'data: {"message": {"content": [{"type": "TEXT", "text": "Hello"}]}}\n', b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE"}\n', b'data: [DONE]\n', ] events = [] for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True): line = raw.decode("utf-8").strip() if line.startswith("data: "): events.append(json.loads(line[6:])) event_types = [e["type"] for e in events] self.assertEqual(event_types[0], "message-start") self.assertEqual(event_types[1], "content-start") self.assertEqual(event_types[2], 
"content-delta") self.assertEqual(event_types[3], "content-delta") self.assertEqual(event_types[4], "content-end") self.assertEqual(event_types[5], "message-end") # Verify message-start has id and role self.assertIn("id", events[0]) self.assertEqual(events[0]["delta"]["message"]["role"], "assistant") # Verify content-start has index and type self.assertEqual(events[1]["index"], 0) self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "text") self.assertEqual(events[5]["delta"]["finish_reason"], "COMPLETE") def test_stream_wrapper_emits_new_content_block_on_thinking_transition(self): """Test streams emit a new content block when transitioning from thinking to text.""" import json from cohere.oci_client import transform_oci_stream_wrapper chunks = [ b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n', b'data: {"message": {"content": [{"type": "TEXT", "text": "Answer"}]}, "finishReason": "COMPLETE"}\n', b"data: [DONE]\n", ] events = [] for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True): line = raw.decode("utf-8").strip() if line.startswith("data: "): events.append(json.loads(line[6:])) self.assertEqual(events[1]["type"], "content-start") self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "thinking") self.assertEqual(events[2]["type"], "content-delta") self.assertEqual(events[2]["index"], 0) self.assertEqual(events[3], {"type": "content-end", "index": 0}) self.assertEqual(events[4]["type"], "content-start") self.assertEqual(events[4]["index"], 1) self.assertEqual(events[4]["delta"]["message"]["content"]["type"], "text") self.assertEqual(events[5]["type"], "content-delta") self.assertEqual(events[5]["index"], 1) def test_stream_wrapper_no_spurious_block_on_finish_only_event(self): """Finish-only event after thinking block must not open a spurious empty text block.""" import json from cohere.oci_client import transform_oci_stream_wrapper chunks = [ b'data: {"message": {"content": 
[{"type": "THINKING", "thinking": "Reasoning..."}]}}\n', b'data: {"finishReason": "COMPLETE"}\n', b"data: [DONE]\n", ] events = [] for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True): line = raw.decode("utf-8").strip() if line.startswith("data: "): events.append(json.loads(line[6:])) types = [e["type"] for e in events] # Must not contain two content-start events self.assertEqual(types.count("content-start"), 1) # The single content block must be thinking cs = next(e for e in events if e["type"] == "content-start") self.assertEqual(cs["delta"]["message"]["content"]["type"], "thinking") # Must end cleanly self.assertEqual(events[-1]["type"], "message-end") def test_stream_wrapper_skips_malformed_json_with_warning(self): """Test that malformed JSON in SSE stream is skipped.""" from cohere.oci_client import transform_oci_stream_wrapper chunks = [ b'data: not-valid-json\n', b'data: {"message": {"content": [{"type": "TEXT", "text": "hello"}]}}\n', b'data: [DONE]\n', ] events = list(transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True)) # Should get message-start + content-start + content-delta + content-end + message-end. 
self.assertEqual(len(events), 5) def test_stream_wrapper_skips_message_end_for_empty_stream(self): """Test empty streams do not emit message-end without a preceding message-start.""" from cohere.oci_client import transform_oci_stream_wrapper events = list(transform_oci_stream_wrapper(iter([b"data: [DONE]\n"]), "chat", is_v2=True)) self.assertEqual(events, []) def test_stream_wrapper_done_uses_current_content_index_after_transition(self): """Test fallback content-end uses the latest content index after type transitions.""" import json from cohere.oci_client import transform_oci_stream_wrapper chunks = [ b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n', b'data: {"message": {"content": [{"type": "TEXT", "text": "Answer"}]}}\n', b"data: [DONE]\n", ] events = [] for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True): line = raw.decode("utf-8").strip() if line.startswith("data: "): events.append(json.loads(line[6:])) self.assertEqual(events[-2], {"type": "content-end", "index": 1}) self.assertEqual(events[-1]["type"], "message-end") def test_stream_wrapper_raises_on_transform_error(self): """Test that transform errors in stream produce OCI-specific error.""" from cohere.oci_client import transform_oci_stream_wrapper # Event with structure that will cause transform_stream_event to fail # (message is None, causing TypeError on "content" in None) chunks = [ b'data: {"message": null}\n', ] with self.assertRaises(RuntimeError) as ctx: list(transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True)) self.assertIn("OCI stream event transformation failed", str(ctx.exception)) def test_stream_event_finish_reason_keeps_final_text(self): """Test finish events keep final text before content-end.""" from cohere.oci_client import transform_stream_event events = transform_stream_event( "chat", { "message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE", }, is_v2=True, ) 
self.assertEqual(events[0]["type"], "content-delta") self.assertEqual(events[0]["delta"]["message"]["content"]["text"], " world") self.assertEqual(events[1]["type"], "content-end") if __name__ == "__main__": unittest.main() ================================================ FILE: tests/test_oci_mypy.py ================================================ """Mypy type-checking gate for OCI client code. Runs mypy on OCI source and test files and fails if any type errors are found. This prevents type regressions from being introduced silently. Run with: pytest tests/test_oci_mypy.py """ import os import shutil import subprocess import unittest MYPY_BIN = shutil.which("mypy") # Files that must stay mypy-clean OCI_SOURCE_FILES = [ "src/cohere/oci_client.py", "src/cohere/manually_maintained/lazy_oci_deps.py", ] OCI_TEST_FILES = [ "tests/test_oci_client.py", ] # --follow-imports=silent prevents mypy from crawling into transitive # dependencies (e.g. the AWS client) that have pre-existing errors. _MYPY_BASE = [ "--config-file", "mypy.ini", "--follow-imports=silent", ] def _run_mypy(files: list[str], extra_env: dict[str, str] | None = None) -> tuple[int, str]: """Run mypy on the given files and return (exit_code, output).""" assert MYPY_BIN is not None env = {**os.environ, **(extra_env or {})} result = subprocess.run( [MYPY_BIN, *_MYPY_BASE, *files], capture_output=True, text=True, env=env, ) return result.returncode, (result.stdout + result.stderr).strip() @unittest.skipIf(MYPY_BIN is None, "mypy not found on PATH") class TestOciMypy(unittest.TestCase): """Ensure OCI files pass mypy with no new errors.""" def test_oci_source_types(self): """OCI source files must be free of mypy errors.""" code, output = _run_mypy(OCI_SOURCE_FILES) self.assertEqual(code, 0, f"mypy found type errors in OCI source:\n{output}") def test_oci_test_types(self): """OCI test files must be free of mypy errors.""" # PYTHONPATH=src so mypy can resolve `import cohere` code, output = _run_mypy(OCI_TEST_FILES, 
extra_env={"PYTHONPATH": "src"}) self.assertEqual(code, 0, f"mypy found type errors in OCI tests:\n{output}") if __name__ == "__main__": unittest.main() ================================================ FILE: tests/test_overrides.py ================================================ import unittest from contextlib import redirect_stderr import logging from cohere import EmbedByTypeResponseEmbeddings LOGGER = logging.getLogger(__name__) class TestClient(unittest.TestCase): def test_float_alias(self) -> None: embeds = EmbedByTypeResponseEmbeddings(float_=[[1.0]]) self.assertEqual(embeds.float_, [[1.0]]) self.assertEqual(embeds.float, [[1.0]]) # type: ignore