Repository: cohere-ai/cohere-python
Branch: main
Commit: 756b1d8ec0e4
Files: 367
Total size: 1.9 MB
Directory structure:
gitextract_zznh6amy/
├── .fern/
│ └── metadata.json
├── .fernignore
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ └── improvement_request.md
│ └── workflows/
│ └── ci.yml
├── .gitignore
├── 4.0.0-5.0.0-migration-guide.md
├── LICENSE
├── README.md
├── mypy.ini
├── pyproject.toml
├── reference.md
├── requirements.txt
├── src/
│ └── cohere/
│ ├── __init__.py
│ ├── _default_clients.py
│ ├── aliases.py
│ ├── audio/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── raw_client.py
│ │ └── transcriptions/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── raw_client.py
│ │ └── types/
│ │ ├── __init__.py
│ │ └── audio_transcriptions_create_response.py
│ ├── aws_client.py
│ ├── base_client.py
│ ├── batches/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── raw_client.py
│ │ └── types/
│ │ ├── __init__.py
│ │ ├── batch.py
│ │ ├── batch_status.py
│ │ ├── cancel_batch_response.py
│ │ ├── create_batch_response.py
│ │ ├── get_batch_response.py
│ │ └── list_batches_response.py
│ ├── bedrock_client.py
│ ├── client.py
│ ├── client_v2.py
│ ├── config.py
│ ├── connectors/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ └── raw_client.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── api_error.py
│ │ ├── client_wrapper.py
│ │ ├── datetime_utils.py
│ │ ├── file.py
│ │ ├── force_multipart.py
│ │ ├── http_client.py
│ │ ├── http_response.py
│ │ ├── http_sse/
│ │ │ ├── __init__.py
│ │ │ ├── _api.py
│ │ │ ├── _decoders.py
│ │ │ ├── _exceptions.py
│ │ │ └── _models.py
│ │ ├── jsonable_encoder.py
│ │ ├── logging.py
│ │ ├── parse_error.py
│ │ ├── pydantic_utilities.py
│ │ ├── query_encoder.py
│ │ ├── remove_none_from_dict.py
│ │ ├── request_options.py
│ │ ├── serialization.py
│ │ └── unchecked_base_model.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── raw_client.py
│ │ └── types/
│ │ ├── __init__.py
│ │ ├── datasets_create_response.py
│ │ ├── datasets_get_response.py
│ │ ├── datasets_get_usage_response.py
│ │ └── datasets_list_response.py
│ ├── embed_jobs/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── raw_client.py
│ │ └── types/
│ │ ├── __init__.py
│ │ └── create_embed_job_request_truncate.py
│ ├── environment.py
│ ├── errors/
│ │ ├── __init__.py
│ │ ├── bad_request_error.py
│ │ ├── client_closed_request_error.py
│ │ ├── forbidden_error.py
│ │ ├── gateway_timeout_error.py
│ │ ├── internal_server_error.py
│ │ ├── invalid_token_error.py
│ │ ├── not_found_error.py
│ │ ├── not_implemented_error.py
│ │ ├── service_unavailable_error.py
│ │ ├── too_many_requests_error.py
│ │ ├── unauthorized_error.py
│ │ └── unprocessable_entity_error.py
│ ├── finetuning/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── finetuning/
│ │ │ ├── __init__.py
│ │ │ └── types/
│ │ │ ├── __init__.py
│ │ │ ├── base_model.py
│ │ │ ├── base_type.py
│ │ │ ├── create_finetuned_model_response.py
│ │ │ ├── delete_finetuned_model_response.py
│ │ │ ├── event.py
│ │ │ ├── finetuned_model.py
│ │ │ ├── get_finetuned_model_response.py
│ │ │ ├── hyperparameters.py
│ │ │ ├── list_events_response.py
│ │ │ ├── list_finetuned_models_response.py
│ │ │ ├── list_training_step_metrics_response.py
│ │ │ ├── lora_target_modules.py
│ │ │ ├── settings.py
│ │ │ ├── status.py
│ │ │ ├── strategy.py
│ │ │ ├── training_step_metrics.py
│ │ │ ├── update_finetuned_model_response.py
│ │ │ └── wandb_config.py
│ │ └── raw_client.py
│ ├── manually_maintained/
│ │ ├── __init__.py
│ │ ├── cache.py
│ │ ├── cohere_aws/
│ │ │ ├── __init__.py
│ │ │ ├── chat.py
│ │ │ ├── classification.py
│ │ │ ├── client.py
│ │ │ ├── embeddings.py
│ │ │ ├── error.py
│ │ │ ├── generation.py
│ │ │ ├── mode.py
│ │ │ ├── rerank.py
│ │ │ ├── response.py
│ │ │ └── summary.py
│ │ ├── lazy_aws_deps.py
│ │ ├── lazy_oci_deps.py
│ │ ├── streaming_embed.py
│ │ └── tokenizers.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ └── raw_client.py
│ ├── oci_client.py
│ ├── overrides.py
│ ├── py.typed
│ ├── raw_base_client.py
│ ├── sagemaker_client.py
│ ├── types/
│ │ ├── __init__.py
│ │ ├── api_meta.py
│ │ ├── api_meta_api_version.py
│ │ ├── api_meta_billed_units.py
│ │ ├── api_meta_tokens.py
│ │ ├── assistant_message.py
│ │ ├── assistant_message_response.py
│ │ ├── assistant_message_response_content_item.py
│ │ ├── assistant_message_v2content.py
│ │ ├── assistant_message_v2content_one_item.py
│ │ ├── auth_token_type.py
│ │ ├── chat_citation.py
│ │ ├── chat_citation_generation_event.py
│ │ ├── chat_citation_type.py
│ │ ├── chat_connector.py
│ │ ├── chat_content_delta_event.py
│ │ ├── chat_content_delta_event_delta.py
│ │ ├── chat_content_delta_event_delta_message.py
│ │ ├── chat_content_delta_event_delta_message_content.py
│ │ ├── chat_content_end_event.py
│ │ ├── chat_content_start_event.py
│ │ ├── chat_content_start_event_delta.py
│ │ ├── chat_content_start_event_delta_message.py
│ │ ├── chat_content_start_event_delta_message_content.py
│ │ ├── chat_content_start_event_delta_message_content_type.py
│ │ ├── chat_data_metrics.py
│ │ ├── chat_debug_event.py
│ │ ├── chat_document.py
│ │ ├── chat_document_source.py
│ │ ├── chat_finish_reason.py
│ │ ├── chat_message.py
│ │ ├── chat_message_end_event.py
│ │ ├── chat_message_end_event_delta.py
│ │ ├── chat_message_start_event.py
│ │ ├── chat_message_start_event_delta.py
│ │ ├── chat_message_start_event_delta_message.py
│ │ ├── chat_message_v2.py
│ │ ├── chat_messages.py
│ │ ├── chat_request_citation_quality.py
│ │ ├── chat_request_prompt_truncation.py
│ │ ├── chat_request_safety_mode.py
│ │ ├── chat_search_queries_generation_event.py
│ │ ├── chat_search_query.py
│ │ ├── chat_search_result.py
│ │ ├── chat_search_result_connector.py
│ │ ├── chat_search_results_event.py
│ │ ├── chat_stream_end_event.py
│ │ ├── chat_stream_end_event_finish_reason.py
│ │ ├── chat_stream_event.py
│ │ ├── chat_stream_event_type.py
│ │ ├── chat_stream_request_citation_quality.py
│ │ ├── chat_stream_request_prompt_truncation.py
│ │ ├── chat_stream_request_safety_mode.py
│ │ ├── chat_stream_start_event.py
│ │ ├── chat_text_content.py
│ │ ├── chat_text_generation_event.py
│ │ ├── chat_text_response_format.py
│ │ ├── chat_text_response_format_v2.py
│ │ ├── chat_thinking_content.py
│ │ ├── chat_tool_call_delta_event.py
│ │ ├── chat_tool_call_delta_event_delta.py
│ │ ├── chat_tool_call_delta_event_delta_message.py
│ │ ├── chat_tool_call_delta_event_delta_message_tool_calls.py
│ │ ├── chat_tool_call_delta_event_delta_message_tool_calls_function.py
│ │ ├── chat_tool_call_end_event.py
│ │ ├── chat_tool_call_start_event.py
│ │ ├── chat_tool_call_start_event_delta.py
│ │ ├── chat_tool_call_start_event_delta_message.py
│ │ ├── chat_tool_calls_chunk_event.py
│ │ ├── chat_tool_calls_generation_event.py
│ │ ├── chat_tool_message.py
│ │ ├── chat_tool_plan_delta_event.py
│ │ ├── chat_tool_plan_delta_event_delta.py
│ │ ├── chat_tool_plan_delta_event_delta_message.py
│ │ ├── chat_tool_source.py
│ │ ├── check_api_key_response.py
│ │ ├── citation.py
│ │ ├── citation_end_event.py
│ │ ├── citation_options.py
│ │ ├── citation_options_mode.py
│ │ ├── citation_start_event.py
│ │ ├── citation_start_event_delta.py
│ │ ├── citation_start_event_delta_message.py
│ │ ├── citation_type.py
│ │ ├── classify_data_metrics.py
│ │ ├── classify_example.py
│ │ ├── classify_request_truncate.py
│ │ ├── classify_response.py
│ │ ├── classify_response_classifications_item.py
│ │ ├── classify_response_classifications_item_classification_type.py
│ │ ├── classify_response_classifications_item_labels_value.py
│ │ ├── compatible_endpoint.py
│ │ ├── connector.py
│ │ ├── connector_auth_status.py
│ │ ├── connector_o_auth.py
│ │ ├── content.py
│ │ ├── create_connector_o_auth.py
│ │ ├── create_connector_response.py
│ │ ├── create_connector_service_auth.py
│ │ ├── create_embed_job_response.py
│ │ ├── dataset.py
│ │ ├── dataset_part.py
│ │ ├── dataset_type.py
│ │ ├── dataset_validation_status.py
│ │ ├── delete_connector_response.py
│ │ ├── detokenize_response.py
│ │ ├── document.py
│ │ ├── document_content.py
│ │ ├── embed_by_type_response.py
│ │ ├── embed_by_type_response_embeddings.py
│ │ ├── embed_by_type_response_response_type.py
│ │ ├── embed_content.py
│ │ ├── embed_floats_response.py
│ │ ├── embed_image.py
│ │ ├── embed_image_url.py
│ │ ├── embed_input.py
│ │ ├── embed_input_type.py
│ │ ├── embed_job.py
│ │ ├── embed_job_status.py
│ │ ├── embed_job_truncate.py
│ │ ├── embed_request_truncate.py
│ │ ├── embed_response.py
│ │ ├── embed_text.py
│ │ ├── embedding_type.py
│ │ ├── finetune_dataset_metrics.py
│ │ ├── finish_reason.py
│ │ ├── generate_request_return_likelihoods.py
│ │ ├── generate_request_truncate.py
│ │ ├── generate_stream_end.py
│ │ ├── generate_stream_end_response.py
│ │ ├── generate_stream_error.py
│ │ ├── generate_stream_event.py
│ │ ├── generate_stream_request_return_likelihoods.py
│ │ ├── generate_stream_request_truncate.py
│ │ ├── generate_stream_text.py
│ │ ├── generate_streamed_response.py
│ │ ├── generation.py
│ │ ├── get_connector_response.py
│ │ ├── get_model_response.py
│ │ ├── get_model_response_sampling_defaults.py
│ │ ├── image.py
│ │ ├── image_content.py
│ │ ├── image_url.py
│ │ ├── image_url_detail.py
│ │ ├── json_response_format.py
│ │ ├── json_response_format_v2.py
│ │ ├── label_metric.py
│ │ ├── list_connectors_response.py
│ │ ├── list_embed_job_response.py
│ │ ├── list_models_response.py
│ │ ├── logprob_item.py
│ │ ├── message.py
│ │ ├── metrics.py
│ │ ├── non_streamed_chat_response.py
│ │ ├── o_auth_authorize_response.py
│ │ ├── parse_info.py
│ │ ├── rerank_document.py
│ │ ├── rerank_request_documents_item.py
│ │ ├── rerank_response.py
│ │ ├── rerank_response_results_item.py
│ │ ├── rerank_response_results_item_document.py
│ │ ├── reranker_data_metrics.py
│ │ ├── response_format.py
│ │ ├── response_format_v2.py
│ │ ├── single_generation.py
│ │ ├── single_generation_in_stream.py
│ │ ├── single_generation_token_likelihoods_item.py
│ │ ├── source.py
│ │ ├── streamed_chat_response.py
│ │ ├── summarize_request_extractiveness.py
│ │ ├── summarize_request_format.py
│ │ ├── summarize_request_length.py
│ │ ├── summarize_response.py
│ │ ├── system_message_v2.py
│ │ ├── system_message_v2content.py
│ │ ├── system_message_v2content_one_item.py
│ │ ├── thinking.py
│ │ ├── thinking_type.py
│ │ ├── tokenize_response.py
│ │ ├── tool.py
│ │ ├── tool_call.py
│ │ ├── tool_call_delta.py
│ │ ├── tool_call_v2.py
│ │ ├── tool_call_v2function.py
│ │ ├── tool_content.py
│ │ ├── tool_message_v2.py
│ │ ├── tool_message_v2content.py
│ │ ├── tool_parameter_definitions_value.py
│ │ ├── tool_result.py
│ │ ├── tool_v2.py
│ │ ├── tool_v2function.py
│ │ ├── update_connector_response.py
│ │ ├── usage.py
│ │ ├── usage_billed_units.py
│ │ ├── usage_tokens.py
│ │ ├── user_message_v2.py
│ │ └── user_message_v2content.py
│ ├── utils.py
│ ├── v2/
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── raw_client.py
│ │ └── types/
│ │ ├── __init__.py
│ │ ├── v2chat_request_documents_item.py
│ │ ├── v2chat_request_safety_mode.py
│ │ ├── v2chat_request_tool_choice.py
│ │ ├── v2chat_response.py
│ │ ├── v2chat_stream_request_documents_item.py
│ │ ├── v2chat_stream_request_safety_mode.py
│ │ ├── v2chat_stream_request_tool_choice.py
│ │ ├── v2chat_stream_response.py
│ │ ├── v2embed_request_truncate.py
│ │ ├── v2rerank_response.py
│ │ └── v2rerank_response_results_item.py
│ └── version.py
└── tests/
├── __init__.py
├── embed_job.jsonl
├── test_async_client.py
├── test_aws_client_unit.py
├── test_bedrock_client.py
├── test_client.py
├── test_client_init.py
├── test_client_v2.py
├── test_embed_streaming.py
├── test_embed_utils.py
├── test_oci_client.py
├── test_oci_mypy.py
└── test_overrides.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .fern/metadata.json
================================================
{
"cliVersion": "4.63.2",
"generatorName": "fernapi/fern-python-sdk",
"generatorVersion": "5.3.3",
"generatorConfig": {
"inline_request_params": false,
"extras": {
"oci": [
"oci"
]
},
"extra_dependencies": {
"fastavro": "^1.9.4",
"requests": "^2.0.0",
"types-requests": "^2.0.0",
"tokenizers": ">=0.15,<1",
"oci": {
"version": "^2.165.0",
"optional": true
}
},
"improved_imports": true,
"pydantic_config": {
"frozen": false,
"union_naming": "v1",
"require_optional_fields": false,
"extra_fields": "allow",
"use_str_enums": true,
"skip_validation": true
},
"timeout_in_seconds": 300,
"client": {
"class_name": "BaseCohere",
"filename": "base_client.py",
"exported_class_name": "Client",
"exported_filename": "client.py"
},
"additional_init_exports": [
{
"from": "client",
"imports": [
"Client",
"AsyncClient"
]
},
{
"from": "bedrock_client",
"imports": [
"BedrockClient",
"BedrockClientV2"
]
},
{
"from": "sagemaker_client",
"imports": [
"SagemakerClient",
"SagemakerClientV2"
]
},
{
"from": "aws_client",
"imports": [
"AwsClient"
]
},
{
"from": "oci_client",
"imports": [
"OciClient",
"OciClientV2"
]
},
{
"from": "client_v2",
"imports": [
"AsyncClientV2",
"ClientV2"
]
},
{
"from": "aliases",
"imports": [
"StreamedChatResponseV2",
"MessageStartStreamedChatResponseV2",
"MessageEndStreamedChatResponseV2",
"ContentStartStreamedChatResponseV2",
"ContentDeltaStreamedChatResponseV2",
"ContentEndStreamedChatResponseV2",
"ToolCallStartStreamedChatResponseV2",
"ToolCallDeltaStreamedChatResponseV2",
"ToolCallEndStreamedChatResponseV2",
"ChatResponse"
]
}
]
},
"originGitCommit": "8dfb5e03f14a05967c4cdeeb44429eb4c1dca198",
"sdkVersion": "6.1.0"
}
================================================
FILE: .fernignore
================================================
4.0.0-5.0.0-migration-guide.md
banner.png
README.md
src/cohere/client.py
tests
.github/workflows/ci.yml
.github/ISSUE_TEMPLATE
LICENSE
.github/workflows/tests.yml
src/cohere/utils.py
src/cohere/overrides.py
src/cohere/config.py
src/cohere/manually_maintained
src/cohere/manually_maintained/__init__.py
src/cohere/bedrock_client.py
src/cohere/aws_client.py
src/cohere/sagemaker_client.py
src/cohere/oci_client.py
src/cohere/client_v2.py
mypy.ini
src/cohere/aliases.py
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report related to an SDK error
about: Create a report to help us improve
title: ''
labels: ''
---
**SDK Version (required)**
Provide the version you are using. To get the version, run the following python snippet
```python
import cohere
print(cohere.__version__) # 5.6.1
```
**Describe the bug**
A clear and concise description of what the bug is.
**Screenshots**
If applicable, add screenshots to help explain your problem.
================================================
FILE: .github/ISSUE_TEMPLATE/improvement_request.md
================================================
---
name: Improvement request or additional features
about: Create a request to help us improve
title: ""
labels: ""
---
**Describe the improvement**
A clear and concise description of what the new improvement is.
**Code snippet of expected outcome**
If applicable, add a code snippet of how you'd like to see the feature implemented
================================================
FILE: .github/workflows/ci.yml
================================================
name: ci
on: [push]
jobs:
compile:
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v3
- name: Set up python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Bootstrap poetry
uses: snok/install-poetry@v1
with:
version: 1.5.1
virtualenvs-in-project: false
- name: Install dependencies
run: poetry install
- name: Compile
run: poetry run mypy .
test:
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v3
- name: Set up python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Bootstrap poetry
uses: snok/install-poetry@v1
with:
version: 1.5.1
virtualenvs-in-project: false
- name: Install dependencies
run: poetry install
- name: Install aws deps
run: poetry run pip install boto3 sagemaker botocore
- name: Test
run: poetry run pytest -rP -n auto .
env:
CO_API_KEY: ${{ secrets.COHERE_API_KEY }}
- name: Install aiohttp extra
run: poetry install --extras aiohttp
- name: Test (aiohttp)
run: poetry run pytest -rP -n auto -m aiohttp . || [ $? -eq 5 ]
env:
CO_API_KEY: ${{ secrets.COHERE_API_KEY }}
publish:
needs: [compile, test]
if: github.event_name == 'push' && contains(github.ref, 'refs/tags/')
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v3
- name: Set up python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Bootstrap poetry
run: |
curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1
- name: Install dependencies
run: poetry install
- name: Publish to pypi
run: |
poetry config repositories.remote https://upload.pypi.org/legacy/
poetry --no-interaction -v publish --build --repository remote --username "$PYPI_USERNAME" --password "$PYPI_PASSWORD"
env:
PYPI_USERNAME: ${{ secrets.PYPI_USERNAME }}
PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
================================================
FILE: .gitignore
================================================
.mypy_cache/
.ruff_cache/
__pycache__/
dist/
poetry.toml
================================================
FILE: 4.0.0-5.0.0-migration-guide.md
================================================
## `cohere==4` to `cohere==5` migration guide
As we migrate from the handwritten, manually-maintained sdk to our auto-generated sdk, there are some breaking changes that must be accommodated during migration. These should mostly improve the developer experience but thank you for bearing with us as we make these changes.
### Installation
To install the latest version of the cohere sdk `pip3 install --upgrade cohere`.
### Migrating usages
#### Migrating function calls
[This diff view](https://github.com/cohere-ai/cohere-python/compare/old-usage...new-usage) enumerates all usages of the old sdk and how they map to the new sdk. Some fields are no longer supported in the new sdk.
#### Migrating streaming usage
The `streaming: boolean` parameter is no longer supported in the new sdk. Instead, you can replace the `chat` function with `chat_stream` and the `generate` function with `generate_stream`. These will automatically inject the `streaming` parameter into the request. The following is an example usage of `chat_stream`:
```python
stream = co.chat_stream(
message="Tell me a short story"
)
for event in stream:
if event.event_type == "text-generation":
print(event.text, end='')
```
### Migrating deprecated `num_workers` Client constructor parameter
The Client constructor accepts an `httpx_client` which can be configured to limit the maximum number of connections.
```python
limits = httpx.Limits(max_connections=10)
cohere.Client(httpx_client=httpx.Client(limits=limits))
```
### Removed functionality (subject to change)
The following lists name the functions that are not in the new SDK and what their ongoing support status is.
#### No longer supported
* check_api_key
* loglikelihood
* batch_generate
* codebook
* batch_tokenize
* batch_detokenize
* detect_language
* generate_feedback
* generate_preference_feedback
* create_cluster_job
* get_cluster_job
* list_cluster_jobs
* wait_for_cluster_job
* create_custom_model
* wait_for_custom_model
* get_custom_model
* get_custom_model_by_name
* get_custom_model_metrics
* list_custom_models
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 Cohere
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Cohere Python SDK

[](https://pypi.org/project/cohere/)

[](https://github.com/fern-api/fern)
The Cohere Python SDK allows access to Cohere models across many different platforms: the Cohere platform, AWS (Bedrock, Sagemaker), Azure, GCP and Oracle OCI. For a full list of supported platforms and code snippets, please take a look at the [SDK support docs page](https://docs.cohere.com/docs/cohere-works-everywhere).
## Documentation
Cohere documentation and API reference is available [here](https://docs.cohere.com/).
## Installation
```
pip install cohere
```
## Usage
```Python
import cohere
co = cohere.ClientV2()
response = co.chat(
model="command-r-plus-08-2024",
messages=[{"role": "user", "content": "hello world!"}],
)
print(response)
```
> [!TIP]
> You can set a system environment variable `CO_API_KEY` to avoid writing your api key within your code, e.g. add `export CO_API_KEY=theapikeyforyouraccount`
> in your ~/.zshrc or ~/.bashrc, open a new terminal, then code calling `cohere.Client()` will read this key.
## Streaming
The SDK supports streaming endpoints. To take advantage of this feature for chat,
use `chat_stream`.
```Python
import cohere
co = cohere.ClientV2()
response = co.chat_stream(
model="command-r-plus-08-2024",
messages=[{"role": "user", "content": "hello world!"}],
)
for event in response:
if event.type == "content-delta":
print(event.delta.message.content.text, end="")
```
## Oracle Cloud Infrastructure (OCI)
The SDK supports Oracle Cloud Infrastructure (OCI) Generative AI service. First, install the OCI SDK:
```
pip install 'cohere[oci]'
```
Then use the `OciClient` or `OciClientV2`:
```Python
import cohere
# Using OCI config file authentication (default: ~/.oci/config)
co = cohere.OciClient(
oci_region="us-chicago-1",
oci_compartment_id="ocid1.compartment.oc1...",
)
response = co.embed(
model="embed-english-v3.0",
texts=["Hello world"],
input_type="search_document",
)
print(response.embeddings)
```
### OCI Authentication Methods
**1. Config File (Default)**
```Python
co = cohere.OciClient(
oci_region="us-chicago-1",
oci_compartment_id="ocid1.compartment.oc1...",
# Uses ~/.oci/config with DEFAULT profile
)
```
**2. Custom Profile**
```Python
co = cohere.OciClient(
oci_profile="MY_PROFILE",
oci_region="us-chicago-1",
oci_compartment_id="ocid1.compartment.oc1...",
)
```
**3. Session-based Authentication (Security Token)**
```Python
# Works with OCI CLI session tokens
co = cohere.OciClient(
oci_profile="MY_SESSION_PROFILE", # Profile with security_token_file
oci_region="us-chicago-1",
oci_compartment_id="ocid1.compartment.oc1...",
)
```
**4. Direct Credentials**
```Python
co = cohere.OciClient(
oci_user_id="ocid1.user.oc1...",
oci_fingerprint="xx:xx:xx:...",
oci_tenancy_id="ocid1.tenancy.oc1...",
oci_private_key_path="~/.oci/key.pem",
oci_region="us-chicago-1",
oci_compartment_id="ocid1.compartment.oc1...",
)
```
**5. Instance Principal (for OCI Compute instances)**
```Python
co = cohere.OciClient(
auth_type="instance_principal",
oci_region="us-chicago-1",
oci_compartment_id="ocid1.compartment.oc1...",
)
```
### Supported OCI APIs
The OCI client supports the following Cohere APIs:
- **Embed**: Full support for all embedding models
- **Chat**: Full support with both V1 (`OciClient`) and V2 (`OciClientV2`) APIs
- Streaming available via `chat_stream()`
- Supports Command-R and Command-A model families
### OCI Model Availability and Limitations
**Available on OCI On-Demand Inference:**
- ✅ **Embed models**: available on OCI Generative AI
- ✅ **Chat models**: available via `OciClient` (V1) and `OciClientV2` (V2)
**Not Available on OCI On-Demand Inference:**
- ❌ **Generate API**: OCI TEXT_GENERATION models are base models that require fine-tuning before deployment
- ❌ **Rerank API**: OCI TEXT_RERANK models are base models that require fine-tuning before deployment
- ❌ **Multiple Embedding Types**: OCI on-demand models only support single embedding type per request (cannot request both `float` and `int8` simultaneously)
**Note**: To use Generate or Rerank models on OCI, you need to:
1. Fine-tune the base model using OCI's fine-tuning service
2. Deploy the fine-tuned model to a dedicated endpoint
3. Update your code to use the deployed model endpoint
For the latest model availability, see the [OCI Generative AI documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm).
## Contributing
While we value open-source contributions to this SDK, the code is generated programmatically. Additions made directly would have to be moved over to our generation code, otherwise they would be overwritten upon the next generated release. Feel free to open a PR as a proof of concept, but know that we will not be able to merge it as-is. We suggest opening an issue first to discuss with us!
On the other hand, contributions to the README are always very welcome!
================================================
FILE: mypy.ini
================================================
[mypy]
exclude = src/cohere/manually_maintained/cohere_aws
================================================
FILE: pyproject.toml
================================================
[project]
name = "cohere"
dynamic = ["version"]
[tool.poetry]
name = "cohere"
version = "6.1.0"
description = ""
readme = "README.md"
authors = []
keywords = []
license = "MIT"
classifiers = [
"Intended Audience :: Developers",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Operating System :: OS Independent",
"Operating System :: POSIX",
"Operating System :: MacOS",
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
"License :: OSI Approved :: MIT License"
]
packages = [
{ include = "cohere", from = "src"}
]
[tool.poetry.urls]
Repository = 'https://github.com/cohere-ai/cohere-python'
[tool.poetry.dependencies]
python = "^3.10"
aiohttp = { version = ">=3.10.0,<4", optional = true}
fastavro = "^1.9.4"
httpx = ">=0.21.2"
httpx-aiohttp = { version = "0.1.8", optional = true}
oci = { version = "^2.165.0", optional = true}
pydantic = ">= 1.9.2"
pydantic-core = ">=2.18.2,<2.44.0"
requests = "^2.0.0"
tokenizers = ">=0.15,<1"
types-requests = "^2.0.0"
typing_extensions = ">= 4.0.0"
[tool.poetry.group.dev.dependencies]
mypy = "==1.13.0"
pytest = "^8.2.0"
pytest-asyncio = "^1.0.0"
pytest-xdist = "^3.6.1"
python-dateutil = "^2.9.0"
types-python-dateutil = "^2.9.0.20240316"
ruff = "==0.11.5"
[tool.pytest.ini_options]
testpaths = [ "tests" ]
asyncio_mode = "auto"
markers = [
"aiohttp: tests that require httpx_aiohttp to be installed",
]
[tool.mypy]
plugins = ["pydantic.mypy"]
[tool.ruff]
line-length = 120
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
]
ignore = [
"E402", # Module level import not at top of file
"E501", # Line too long
"E711", # Comparison to `None` should be `cond is not None`
"E712", # Avoid equality comparisons to `True`; use `if ...:` checks
"E721", # Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
"E722", # Do not use bare `except`
"E731", # Do not assign a `lambda` expression, use a `def`
"F821", # Undefined name
"F841" # Local variable ... is assigned to but never used
]
[tool.ruff.lint.isort]
section-order = ["future", "standard-library", "third-party", "first-party"]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.extras]
oci=["oci"]
aiohttp=["aiohttp", "httpx-aiohttp"]
================================================
FILE: reference.md
================================================
# Reference
client.chat_stream(...) -> typing.Iterator[bytes]
-
#### 📝 Description
-
-
Generates a streamed text response to a user message.
To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.chat_stream(
model="command-a-03-2025",
message="hello!",
)
```
#### ⚙️ Parameters
-
-
**message:** `str`
Text input for the model to respond to.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**stream:** `typing.Literal`
Defaults to `false`.
When `true`, the response will be a JSON stream of events. The final event will contain the complete response, and will have an `event_type` of `"stream-end"`.
Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**accepts:** `typing.Optional[typing.Literal]` — Pass text/event-stream to receive the streamed response as server-sent events. The default is `\n` delimited events.
-
**model:** `typing.Optional[str]`
The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
Compatible Deployments: Cohere Platform, Private Deployments
-
**preamble:** `typing.Optional[str]`
When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**chat_history:** `typing.Optional[typing.List[Message]]`
A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**conversation_id:** `typing.Optional[str]`
An alternative to `chat_history`.
Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
Compatible Deployments: Cohere Platform
-
**prompt_truncation:** `typing.Optional[ChatStreamRequestPromptTruncation]`
Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
Dictates how the prompt will be constructed.
With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
Compatible Deployments:
- AUTO: Cohere Platform Only
- AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**connectors:** `typing.Optional[typing.List[ChatConnector]]`
Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
Compatible Deployments: Cohere Platform
-
**search_queries_only:** `typing.Optional[bool]`
Defaults to `false`.
When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**documents:** `typing.Optional[typing.List[ChatDocument]]`
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
Example:
```
[
{ "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
{ "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
]
```
Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**citation_quality:** `typing.Optional[ChatStreamRequestCitationQuality]`
Defaults to `"enabled"`.
Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**temperature:** `typing.Optional[float]`
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**max_tokens:** `typing.Optional[int]`
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**max_input_tokens:** `typing.Optional[int]`
The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
Input will be truncated according to the `prompt_truncation` parameter.
Compatible Deployments: Cohere Platform
-
**k:** `typing.Optional[int]`
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**p:** `typing.Optional[float]`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**seed:** `typing.Optional[int]`
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**stop_sequences:** `typing.Optional[typing.List[str]]`
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**frequency_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**presence_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**raw_prompting:** `typing.Optional[bool]`
When enabled, the user's prompt will be sent to the model without
any pre-processing.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**tools:** `typing.Optional[typing.List[Tool]]`
A list of available tools (functions) that the model may suggest invoking before producing a text response.
When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**tool_results:** `typing.Optional[typing.List[ToolResult]]`
A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
**Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
```
tool_results = [
{
"call": {
"name": ,
"parameters": {
:
}
},
"outputs": [{
:
}]
},
...
]
```
**Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**force_single_step:** `typing.Optional[bool]` — Forces the chat to be single step. Defaults to `false`.
-
**response_format:** `typing.Optional[ResponseFormat]`
-
**safety_mode:** `typing.Optional[ChatStreamRequestSafetyMode]`
Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `NONE` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
**Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.chat(...) -> NonStreamedChatResponse
-
#### 📝 Description
-
-
Generates a text response to a user message.
To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.chat(
model="command-a-03-2025",
message="Tell me about LLMs",
)
```
#### ⚙️ Parameters
-
-
**message:** `str`
Text input for the model to respond to.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**stream:** `typing.Literal`
Defaults to `false`.
When `true`, the response will be a JSON stream of events. The final event will contain the complete response, and will have an `event_type` of `"stream-end"`.
Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**accepts:** `typing.Optional[typing.Literal]` — Pass text/event-stream to receive the streamed response as server-sent events. The default is `\n` delimited events.
-
**model:** `typing.Optional[str]`
The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
Compatible Deployments: Cohere Platform, Private Deployments
-
**preamble:** `typing.Optional[str]`
When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**chat_history:** `typing.Optional[typing.List[Message]]`
A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**conversation_id:** `typing.Optional[str]`
An alternative to `chat_history`.
Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
Compatible Deployments: Cohere Platform
-
**prompt_truncation:** `typing.Optional[ChatRequestPromptTruncation]`
Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
Dictates how the prompt will be constructed.
With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
Compatible Deployments:
- AUTO: Cohere Platform Only
- AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**connectors:** `typing.Optional[typing.List[ChatConnector]]`
Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
Compatible Deployments: Cohere Platform
-
**search_queries_only:** `typing.Optional[bool]`
Defaults to `false`.
When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**documents:** `typing.Optional[typing.List[ChatDocument]]`
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
Example:
```
[
{ "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
{ "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
]
```
Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**citation_quality:** `typing.Optional[ChatRequestCitationQuality]`
Defaults to `"enabled"`.
Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**temperature:** `typing.Optional[float]`
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**max_tokens:** `typing.Optional[int]`
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**max_input_tokens:** `typing.Optional[int]`
The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
Input will be truncated according to the `prompt_truncation` parameter.
Compatible Deployments: Cohere Platform
-
**k:** `typing.Optional[int]`
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**p:** `typing.Optional[float]`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**seed:** `typing.Optional[int]`
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**stop_sequences:** `typing.Optional[typing.List[str]]`
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**frequency_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**presence_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**raw_prompting:** `typing.Optional[bool]`
When enabled, the user's prompt will be sent to the model without
any pre-processing.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**tools:** `typing.Optional[typing.List[Tool]]`
A list of available tools (functions) that the model may suggest invoking before producing a text response.
When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**tool_results:** `typing.Optional[typing.List[ToolResult]]`
A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
**Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
```
tool_results = [
{
"call": {
"name": ,
"parameters": {
:
}
},
"outputs": [{
:
}]
},
...
]
```
**Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**force_single_step:** `typing.Optional[bool]` — Forces the chat to be single step. Defaults to `false`.
-
**response_format:** `typing.Optional[ResponseFormat]`
-
**safety_mode:** `typing.Optional[ChatRequestSafetyMode]`
Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `NONE` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
**Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.generate_stream(...) -> typing.Iterator[bytes]
-
#### 📝 Description
-
-
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.
Generates realistic text conditioned on a given input.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.generate_stream(
prompt="Please explain to me how LLMs work",
)
```
#### ⚙️ Parameters
-
-
**prompt:** `str`
The input text that serves as the starting point for generating the response.
Note: The prompt will be pre-processed and modified before reaching the model.
-
**stream:** `typing.Literal`
When `true`, the response will be a JSON stream of events. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
The final event will contain the complete response, and will contain an `is_finished` field set to `true`. The event will also contain a `finish_reason`, which can be one of the following:
- `COMPLETE` - the model sent back a finished reply
- `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens for its context length
- `ERROR` - something went wrong when generating the reply
- `ERROR_TOXIC` - the model generated a reply that was deemed toxic
-
**model:** `typing.Optional[str]`
The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
-
**num_generations:** `typing.Optional[int]` — The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.
-
**max_tokens:** `typing.Optional[int]`
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.
-
**truncate:** `typing.Optional[GenerateStreamRequestTruncate]`
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
-
**temperature:** `typing.Optional[float]`
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
-
**seed:** `typing.Optional[int]`
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**preset:** `typing.Optional[str]`
Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.
-
**end_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
-
**stop_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.
-
**k:** `typing.Optional[int]`
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
-
**p:** `typing.Optional[float]`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
-
**frequency_penalty:** `typing.Optional[float]`
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
-
**presence_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
-
**return_likelihoods:** `typing.Optional[GenerateStreamRequestReturnLikelihoods]`
One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
WARNING: `ALL` is deprecated, and will be removed in a future release.
-
**raw_prompting:** `typing.Optional[bool]` — When enabled, the user's prompt will be sent to the model without any pre-processing.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.generate(...) -> Generation
-
#### 📝 Description
-
-
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.
Generates realistic text conditioned on a given input.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.generate(
prompt="Please explain to me how LLMs work",
)
```
#### ⚙️ Parameters
-
-
**prompt:** `str`
The input text that serves as the starting point for generating the response.
Note: The prompt will be pre-processed and modified before reaching the model.
-
**stream:** `typing.Literal`
When `true`, the response will be a JSON stream of events. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
The final event will contain the complete response, and will contain an `is_finished` field set to `true`. The event will also contain a `finish_reason`, which can be one of the following:
- `COMPLETE` - the model sent back a finished reply
- `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens for its context length
- `ERROR` - something went wrong when generating the reply
- `ERROR_TOXIC` - the model generated a reply that was deemed toxic
-
**model:** `typing.Optional[str]`
The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
-
**num_generations:** `typing.Optional[int]` — The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.
-
**max_tokens:** `typing.Optional[int]`
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.
-
**truncate:** `typing.Optional[GenerateRequestTruncate]`
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
-
**temperature:** `typing.Optional[float]`
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
-
**seed:** `typing.Optional[int]`
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
-
**preset:** `typing.Optional[str]`
Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.
-
**end_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
-
**stop_sequences:** `typing.Optional[typing.List[str]]` — The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included in the text.
-
**k:** `typing.Optional[int]`
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
-
**p:** `typing.Optional[float]`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`, min value of `0.01`, max value of `0.99`.
-
**frequency_penalty:** `typing.Optional[float]`
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
-
**presence_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
-
**return_likelihoods:** `typing.Optional[GenerateRequestReturnLikelihoods]`
One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
WARNING: `ALL` is deprecated, and will be removed in a future release.
-
**raw_prompting:** `typing.Optional[bool]` — When enabled, the user's prompt will be sent to the model without any pre-processing.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed(...) -> EmbedResponse
-
#### 📝 Description
-
-
This endpoint returns text and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents.
Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.
If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.embed(
texts=[
"hello",
"goodbye"
],
model="embed-v4.0",
input_type="classification",
)
```
#### ⚙️ Parameters
-
-
**texts:** `typing.Optional[typing.List[str]]` — An array of strings for the model to embed. Maximum number of texts per call is `96`.
-
**images:** `typing.Optional[typing.List[str]]`
An array of image data URIs for the model to embed. Maximum number of images per call is `1`.
The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.
Images are only supported with Embed v3.0 and newer models.
-
**model:** `typing.Optional[str]` — ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
-
**input_type:** `typing.Optional[EmbedInputType]`
-
**embedding_types:** `typing.Optional[typing.List[EmbeddingType]]`
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
* `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
-
**truncate:** `typing.Optional[EmbedRequestTruncate]`
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.rerank(...) -> RerankResponse
-
#### 📝 Description
-
-
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.rerank(
documents=[
{
"text": "Carson City is the capital city of the American state of Nevada."
},
{
"text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
},
{
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
},
{
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
},
{
"text": "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
}
],
query="What is the capital of the United States?",
top_n=3,
model="rerank-v4.0-pro",
)
```
#### ⚙️ Parameters
-
-
**query:** `str` — The search query
-
**documents:** `typing.List[RerankRequestDocumentsItem]`
A list of document objects or strings to rerank.
If a document object is provided, the `text` field is required and all other fields will be preserved in the response.
The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000.
We recommend a maximum of 1,000 documents for optimal endpoint performance.
-
**model:** `typing.Optional[str]` — The identifier of the model to use, eg `rerank-v3.5`.
-
**top_n:** `typing.Optional[int]` — The number of most relevant documents or indices to return, defaults to the length of the documents
-
**rank_fields:** `typing.Optional[typing.List[str]]` — If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially. If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking.
-
**return_documents:** `typing.Optional[bool]`
- If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request.
- If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.
-
**max_chunks_per_doc:** `typing.Optional[int]` — The maximum number of chunks to produce internally from a document
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.classify(...) -> ClassifyResponse
-
#### 📝 Description
-
-
This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference.
Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
#### 🔌 Usage
-
-
```python
from cohere import Client, ClassifyExample
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.classify(
examples=[
ClassifyExample(
text="Dermatologists don\'t like her!",
label="Spam",
),
ClassifyExample(
text="\'Hello, open to this?\'",
label="Spam",
),
ClassifyExample(
text="I need help please wire me $1000 right now",
label="Spam",
),
ClassifyExample(
text="Nice to know you ;)",
label="Spam",
),
ClassifyExample(
text="Please help me?",
label="Spam",
),
ClassifyExample(
text="Your parcel will be delivered today",
label="Not spam",
),
ClassifyExample(
text="Review changes to our Terms and Conditions",
label="Not spam",
),
ClassifyExample(
text="Weekly sync notes",
label="Not spam",
),
ClassifyExample(
text="\'Re: Follow up from today\'s meeting\'",
label="Not spam",
),
ClassifyExample(
text="Pre-read for tomorrow",
label="Not spam",
)
],
inputs=[
"Confirm your email address",
"hey i need u to send some $"
],
model="YOUR-FINE-TUNED-MODEL-ID",
)
```
#### ⚙️ Parameters
-
-
**inputs:** `typing.List[str]`
A list of up to 96 texts to be classified. Each one must be a non-empty string.
There is, however, no consistent, universal limit to the length a particular input can be. We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models).
Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts.
-
**examples:** `typing.Optional[typing.List[ClassifyExample]]`
An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`.
Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
-
**model:** `typing.Optional[str]` — ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model
-
**preset:** `typing.Optional[str]` — The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground). If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters.
-
**truncate:** `typing.Optional[ClassifyRequestTruncate]`
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.summarize(...) -> SummarizeResponse
-
#### 📝 Description
-
-
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.
Generates a summary in English for a given text.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.summarize(
text="Ice cream is a sweetened frozen food typically eaten as a snack or dessert. It may be made from milk or cream and is flavoured with a sweetener, either sugar or an alternative, and a spice, such as cocoa or vanilla, or with fruit such as strawberries or peaches. It can also be made by whisking a flavored cream base and liquid nitrogen together. Food coloring is sometimes added, in addition to stabilizers. The mixture is cooled below the freezing point of water and stirred to incorporate air spaces and to prevent detectable ice crystals from forming. The result is a smooth, semi-solid foam that is solid at very low temperatures (below 2 °C or 35 °F). It becomes more malleable as its temperature increases.\n\nThe meaning of the name \"ice cream\" varies from one country to another. In some countries, such as the United States, \"ice cream\" applies only to a specific variety, and most governments regulate the commercial use of the various terms according to the relative quantities of the main ingredients, notably the amount of cream. Products that do not meet the criteria to be called ice cream are sometimes labelled \"frozen dairy dessert\" instead. In other countries, such as Italy and Argentina, one word is used for all variants. Analogues made from dairy alternatives, such as goat\'s or sheep\'s milk, or milk substitutes (e.g., soy, cashew, coconut, almond milk or tofu), are available for those who are lactose intolerant, allergic to dairy protein or vegan.",
)
```
#### ⚙️ Parameters
-
-
**text:** `str` — The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English.
-
**length:** `typing.Optional[SummarizeRequestLength]` — One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text.
-
**format:** `typing.Optional[SummarizeRequestFormat]` — One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text.
-
**model:** `typing.Optional[str]` — The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better.
-
**extractiveness:** `typing.Optional[SummarizeRequestExtractiveness]` — One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text.
-
**temperature:** `typing.Optional[float]` — Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
-
**additional_command:** `typing.Optional[str]` — A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda"
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.tokenize(...) -> TokenizeResponse
-
#### 📝 Description
-
-
This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.tokenize(
text="tokenize me! :D",
model="command",
)
```
#### ⚙️ Parameters
-
-
**text:** `str` — The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters.
-
**model:** `str` — The input will be tokenized by the tokenizer that is used by this model.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.detokenize(...) -> DetokenizeResponse
-
#### 📝 Description
-
-
This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.detokenize(
tokens=[
10002,
2261,
2012,
8,
2792,
43
],
model="command",
)
```
#### ⚙️ Parameters
-
-
**tokens:** `typing.List[int]` — The list of tokens to be detokenized.
-
**model:** `str` — An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.check_api_key() -> CheckApiKeyResponse
-
#### 📝 Description
-
-
Checks that the api key in the Authorization header is valid and active
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.check_api_key()
```
#### ⚙️ Parameters
-
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## V2
client.v2.chat_stream(...) -> typing.Iterator[bytes]
-
#### 📝 Description
-
-
Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
#### 🔌 Usage
-
-
```python
from cohere import Client, ChatMessageV2_User
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.v2.chat_stream(
model="command-a-03-2025",
messages=[
ChatMessageV2_User(
content="Tell me about LLMs",
)
],
)
```
#### ⚙️ Parameters
-
-
**stream:** `typing.Literal`
Defaults to `false`.
When `true`, the response will be a SSE stream of events.
Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
-
**model:** `str` — The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).
-
**messages:** `ChatMessages`
-
**tools:** `typing.Optional[typing.List[ToolV2]]`
A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).
-
**strict_tools:** `typing.Optional[bool]`
When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
**Note**: The first few requests with a new set of tools will take longer to process.
-
**documents:** `typing.Optional[typing.List[V2ChatStreamRequestDocumentsItem]]` — A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
-
**citation_options:** `typing.Optional[CitationOptions]`
-
**response_format:** `typing.Optional[ResponseFormatV2]`
-
**safety_mode:** `typing.Optional[V2ChatStreamRequestSafetyMode]`
Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `OFF` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools` and `documents` parameters.
**Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
-
**max_tokens:** `typing.Optional[int]`
The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).
**Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.
**Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.
-
**stop_sequences:** `typing.Optional[typing.List[str]]` — A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
-
**temperature:** `typing.Optional[float]`
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
-
**seed:** `typing.Optional[int]`
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
-
**frequency_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
-
**presence_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
-
**k:** `typing.Optional[int]`
Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
Defaults to `0`, min value of `0`, max value of `500`.
-
**p:** `typing.Optional[float]`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`, min value of `0.01`, max value of `0.99`.
-
**logprobs:** `typing.Optional[bool]` — Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
-
**tool_choice:** `typing.Optional[V2ChatStreamRequestToolChoice]`
Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.
**Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.
-
**thinking:** `typing.Optional[Thinking]`
-
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.v2.chat(...) -> V2ChatResponse
-
#### 📝 Description
-
-
Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
#### 🔌 Usage
-
-
```python
from cohere import Client, ChatMessageV2_User
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.v2.chat_stream(
model="command-a-03-2025",
messages=[
ChatMessageV2_User(
content="Tell me about LLMs",
)
],
)
```
#### ⚙️ Parameters
-
-
**stream:** `typing.Literal`
Defaults to `false`.
When `true`, the response will be a SSE stream of events.
Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
-
**model:** `str` — The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).
-
**messages:** `ChatMessages`
-
**tools:** `typing.Optional[typing.List[ToolV2]]`
A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).
-
**strict_tools:** `typing.Optional[bool]`
When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
**Note**: The first few requests with a new set of tools will take longer to process.
-
**documents:** `typing.Optional[typing.List[V2ChatRequestDocumentsItem]]` — A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
-
**citation_options:** `typing.Optional[CitationOptions]`
-
**response_format:** `typing.Optional[ResponseFormatV2]`
-
**safety_mode:** `typing.Optional[V2ChatRequestSafetyMode]`
Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `OFF` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools` and `documents` parameters.
**Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
-
**max_tokens:** `typing.Optional[int]`
The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).
**Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.
**Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.
-
**stop_sequences:** `typing.Optional[typing.List[str]]` — A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point, not including the stop sequence.
-
**temperature:** `typing.Optional[float]`
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
-
**seed:** `typing.Optional[int]`
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
-
**frequency_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
-
**presence_penalty:** `typing.Optional[float]`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
-
**k:** `typing.Optional[int]`
Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
Defaults to `0`, min value of `0`, max value of `500`.
-
**p:** `typing.Optional[float]`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`, min value of `0.01`, max value of `0.99`.
-
**logprobs:** `typing.Optional[bool]` — Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
-
**tool_choice:** `typing.Optional[V2ChatRequestToolChoice]`
Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.
**Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.
-
**thinking:** `typing.Optional[Thinking]`
-
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.v2.embed(...) -> EmbedByTypeResponse
-
#### 📝 Description
-
-
This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.
Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.
If you want to learn more about how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.v2.embed(
texts=[
"hello",
"goodbye"
],
model="embed-v4.0",
input_type="classification",
embedding_types=[
"float"
],
)
```
#### ⚙️ Parameters
-
-
**model:** `str` — ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
-
**input_type:** `EmbedInputType`
-
**texts:** `typing.Optional[typing.List[str]]` — An array of strings for the model to embed. Maximum number of texts per call is `96`.
-
**images:** `typing.Optional[typing.List[str]]`
An array of image data URIs for the model to embed. Maximum number of images per call is `1`.
The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.
Image embeddings are supported with Embed v3.0 and newer models.
-
**inputs:** `typing.Optional[typing.List[EmbedInput]]` — An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components.
-
**max_tokens:** `typing.Optional[int]` — The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.
-
**output_dimension:** `typing.Optional[int]`
The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models.
Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.
-
**embedding_types:** `typing.Optional[typing.List[EmbeddingType]]`
Specifies the types of embeddings you want to get back. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
* `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
* `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models.
-
**truncate:** `typing.Optional[V2EmbedRequestTruncate]`
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
-
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.v2.rerank(...) -> V2RerankResponse
-
#### 📝 Description
-
-
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.v2.rerank(
documents=[
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
],
query="What is the capital of the United States?",
top_n=3,
model="rerank-v4.0-pro",
)
```
#### ⚙️ Parameters
-
-
**model:** `str` — The identifier of the model to use, eg `rerank-v3.5`.
-
**query:** `str` — The search query
-
**documents:** `typing.List[str]`
A list of texts that will be compared to the `query`.
For optimal performance we recommend against sending more than 1,000 documents in a single request.
**Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`.
**Note**: structured data should be formatted as YAML strings for best performance.
-
**top_n:** `typing.Optional[int]` — Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.
-
**max_tokens_per_doc:** `typing.Optional[int]` — Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.
-
**priority:** `typing.Optional[int]` — Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Batches
client.batches.list(...) -> ListBatchesResponse
-
#### 📝 Description
-
-
List the batches for the current user
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.batches.list(
page_size=1,
page_token="page_token",
order_by="order_by",
)
```
#### ⚙️ Parameters
-
-
**page_size:** `typing.Optional[int]`
The maximum number of batches to return. The service may return fewer than
this value.
If unspecified, at most 50 batches will be returned.
The maximum value is 1000; values above 1000 will be coerced to 1000.
-
**page_token:** `typing.Optional[str]`
A page token, received from a previous `ListBatches` call.
Provide this to retrieve the subsequent page.
-
**order_by:** `typing.Optional[str]`
Batches can be ordered by creation time or last updated time.
Use `created_at` for creation time or `updated_at` for last updated time.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.batches.create(...) -> CreateBatchResponse
-
#### 📝 Description
-
-
Creates and executes a batch from an uploaded dataset of requests
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
from cohere.batches import Batch
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.batches.create(
request=Batch(
name="name",
input_dataset_id="input_dataset_id",
model="model",
),
)
```
#### ⚙️ Parameters
-
-
**request:** `Batch`
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.batches.retrieve(...) -> GetBatchResponse
-
#### 📝 Description
-
-
Retrieves a batch
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.batches.retrieve(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The batch ID.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.batches.cancel(...) -> CancelBatchResponse
-
#### 📝 Description
-
-
Cancels an in-progress batch
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.batches.cancel(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The batch ID.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## EmbedJobs
client.embed_jobs.list() -> ListEmbedJobResponse
-
#### 📝 Description
-
-
The list embed job endpoint allows users to view all embed jobs history for that specific user.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.embed_jobs.list()
```
#### ⚙️ Parameters
-
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed_jobs.create(...) -> CreateEmbedJobResponse
-
#### 📝 Description
-
-
This API launches an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`. The result of a completed embed job is new Dataset of type `embed-output`, which contains the original text entries and the corresponding embeddings.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.embed_jobs.create(
model="model",
dataset_id="dataset_id",
input_type="search_document",
)
```
#### ⚙️ Parameters
-
-
**model:** `str`
ID of the embedding model.
Available models and corresponding embedding dimensions:
- `embed-english-v3.0` : 1024
- `embed-multilingual-v3.0` : 1024
- `embed-english-light-v3.0` : 384
- `embed-multilingual-light-v3.0` : 384
-
**dataset_id:** `str` — ID of a [Dataset](https://docs.cohere.com/docs/datasets). The Dataset must be of type `embed-input` and must have a validation status `Validated`
-
**input_type:** `EmbedInputType`
-
**name:** `typing.Optional[str]` — The name of the embed job.
-
**embedding_types:** `typing.Optional[typing.List[EmbeddingType]]`
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for v3 and newer model versions.
* `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for v3 and newer model versions.
* `"binary"`: Use this when you want to get back signed binary embeddings. Valid for v3 and newer model versions.
* `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for v3 and newer model versions.
-
**truncate:** `typing.Optional[CreateEmbedJobRequestTruncate]`
One of `START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed_jobs.get(...) -> EmbedJob
-
#### 📝 Description
-
-
This API retrieves the details about an embed job started by the same user.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.embed_jobs.get(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The ID of the embed job to retrieve.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.embed_jobs.cancel(...)
-
#### 📝 Description
-
-
This API allows users to cancel an active embed job. Once invoked, the embedding process will be terminated, and users will be charged for the embeddings processed up to the cancellation point. It's important to note that partial results will not be available to users after cancellation.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.embed_jobs.cancel(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The ID of the embed job to cancel.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Datasets
client.datasets.list(...) -> DatasetsListResponse
-
#### 📝 Description
-
-
List datasets that have been created.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
import datetime
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.datasets.list(
dataset_type="datasetType",
before=datetime.datetime.fromisoformat("2024-01-15T09:30:00+00:00"),
after=datetime.datetime.fromisoformat("2024-01-15T09:30:00+00:00"),
limit=1.1,
offset=1.1,
validation_status="unknown",
)
```
#### ⚙️ Parameters
-
-
**dataset_type:** `typing.Optional[str]` — optional filter by dataset type
-
**before:** `typing.Optional[datetime.datetime]` — optional filter before a date
-
**after:** `typing.Optional[datetime.datetime]` — optional filter after a date
-
**limit:** `typing.Optional[float]` — optional limit to number of results
-
**offset:** `typing.Optional[float]` — optional offset to start of results
-
**validation_status:** `typing.Optional[DatasetValidationStatus]` — optional filter by validation status
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.create(...) -> DatasetsCreateResponse
-
#### 📝 Description
-
-
Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.datasets.create(
name="name",
type="embed-input",
keep_original_file=True,
skip_malformed_input=True,
text_separator="text_separator",
csv_delimiter="csv_delimiter",
data="example_data",
eval_data="example_eval_data",
)
```
#### ⚙️ Parameters
-
-
**name:** `str` — The name of the uploaded dataset.
-
**type:** `DatasetType` — The dataset type, which is used to validate the data. The only valid type is `embed-input` used in conjunction with the Embed Jobs API.
-
**data:** `core.File` — The file to upload
-
**keep_original_file:** `typing.Optional[bool]` — Indicates if the original file should be stored.
-
**skip_malformed_input:** `typing.Optional[bool]` — Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field.
-
**keep_fields:** `typing.Optional[typing.Union[str, typing.Sequence[str]]]` — List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `keep_fields` are missing from the uploaded file, Dataset validation will fail.
-
**optional_fields:** `typing.Optional[typing.Union[str, typing.Sequence[str]]]` — List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, Datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `optional_fields` are missing from the uploaded file, Dataset validation will pass.
-
**text_separator:** `typing.Optional[str]` — Raw .txt uploads will be split into entries using the text_separator value.
-
**csv_delimiter:** `typing.Optional[str]` — The delimiter used for .csv uploads.
-
**eval_data:** `typing.Optional[core.File]` — An optional evaluation file to upload
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.get_usage() -> DatasetsGetUsageResponse
-
#### 📝 Description
-
-
View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.datasets.get_usage()
```
#### ⚙️ Parameters
-
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.get(...) -> DatasetsGetResponse
-
#### 📝 Description
-
-
Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.datasets.get(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str`
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.datasets.delete(...) -> typing.Dict[str, typing.Any]
-
#### 📝 Description
-
-
Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.datasets.delete(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str`
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Connectors
client.connectors.list(...) -> ListConnectorsResponse
-
#### 📝 Description
-
-
Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.connectors.list(
limit=1.1,
offset=1.1,
)
```
#### ⚙️ Parameters
-
-
**limit:** `typing.Optional[float]` — Maximum number of connectors to return [0, 100].
-
**offset:** `typing.Optional[float]` — Number of connectors to skip before returning results [0, inf].
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.create(...) -> CreateConnectorResponse
-
#### 📝 Description
-
-
Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.connectors.create(
name="name",
url="url",
)
```
#### ⚙️ Parameters
-
-
**name:** `str` — A human-readable name for the connector.
-
**url:** `str` — The URL of the connector that will be used to search for documents.
-
**description:** `typing.Optional[str]` — A description of the connector.
-
**excludes:** `typing.Optional[typing.List[str]]` — A list of fields to exclude from the prompt (fields remain in the document).
-
**oauth:** `typing.Optional[CreateConnectorOAuth]` — The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.
-
**active:** `typing.Optional[bool]` — Whether the connector is active or not.
-
**continue_on_failure:** `typing.Optional[bool]` — Whether a chat request should continue or not if the request to this connector fails.
-
**service_auth:** `typing.Optional[CreateConnectorServiceAuth]` — The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.get(...) -> GetConnectorResponse
-
#### 📝 Description
-
-
Retrieve a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.connectors.get(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The ID of the connector to retrieve.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.delete(...) -> DeleteConnectorResponse
-
#### 📝 Description
-
-
Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.connectors.delete(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The ID of the connector to delete.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.update(...) -> UpdateConnectorResponse
-
#### 📝 Description
-
-
Update a connector by ID. Omitted fields will not be updated. See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.connectors.update(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The ID of the connector to update.
-
**name:** `typing.Optional[str]` — A human-readable name for the connector.
-
**url:** `typing.Optional[str]` — The URL of the connector that will be used to search for documents.
-
**excludes:** `typing.Optional[typing.List[str]]` — A list of fields to exclude from the prompt (fields remain in the document).
-
**oauth:** `typing.Optional[CreateConnectorOAuth]` — The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.
-
**active:** `typing.Optional[bool]`
-
**continue_on_failure:** `typing.Optional[bool]`
-
**service_auth:** `typing.Optional[CreateConnectorServiceAuth]` — The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.connectors.o_auth_authorize(...) -> OAuthAuthorizeResponse
-
#### 📝 Description
-
-
Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.connectors.o_auth_authorize(
id="id",
after_token_redirect="after_token_redirect",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The ID of the connector to authorize.
-
**after_token_redirect:** `typing.Optional[str]` — The URL to redirect to after the connector has been authorized.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Models
client.models.get(...) -> GetModelResponse
-
#### 📝 Description
-
-
Returns the details of a model, provided its name.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.models.get(
model="command-a-03-2025",
)
```
#### ⚙️ Parameters
-
-
**model:** `str`
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.models.list(...) -> ListModelsResponse
-
#### 📝 Description
-
-
Returns a list of models available for use.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.models.list(
page_size=1.1,
page_token="page_token",
endpoint="chat",
default_only=True,
)
```
#### ⚙️ Parameters
-
-
**page_size:** `typing.Optional[float]`
Maximum number of models to include in a page
Defaults to `20`, min value of `1`, max value of `1000`.
-
**page_token:** `typing.Optional[str]` — Page token provided in the `next_page_token` field of a previous response.
-
**endpoint:** `typing.Optional[CompatibleEndpoint]` — When provided, filters the list of models to only those that are compatible with the specified endpoint.
-
**default_only:** `typing.Optional[bool]` — When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## /finetuning
client.finetuning.list_finetuned_models(...) -> ListFinetunedModelsResponse
-
#### 📝 Description
-
-
Returns a list of fine-tuned models that the user has access to.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.finetuning.list_finetuned_models(
page_size=1,
page_token="page_token",
order_by="order_by",
)
```
#### ⚙️ Parameters
-
-
**page_size:** `typing.Optional[int]`
Maximum number of results to be returned by the server. If 0, defaults to
50.
-
**page_token:** `typing.Optional[str]` — Request a specific page of the list results.
-
**order_by:** `typing.Optional[str]`
Comma-separated list of fields. For example: "created_at,name". The default
sorting order is ascending. To specify descending order for a field, append
" desc" to the field name. For example: "created_at desc,name".
Supported sorting fields:
- created_at (default)
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.create_finetuned_model(...) -> CreateFinetunedModelResponse
-
#### 📝 Description
-
-
Creates a new fine-tuned model. The model will be trained on the dataset specified in the request body. The training process may take some time, and the model will be available once the training is complete.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
from cohere.finetuning.finetuning import FinetunedModel, Settings, BaseModel
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.finetuning.create_finetuned_model(
request=FinetunedModel(
name="name",
settings=Settings(
base_model=BaseModel(
base_type="BASE_TYPE_UNSPECIFIED",
),
dataset_id="dataset_id",
),
),
)
```
#### ⚙️ Parameters
-
-
**request:** `FinetunedModel`
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.get_finetuned_model(...) -> GetFinetunedModelResponse
-
#### 📝 Description
-
-
Retrieve a fine-tuned model by its ID.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.finetuning.get_finetuned_model(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The fine-tuned model ID.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.delete_finetuned_model(...) -> DeleteFinetunedModelResponse
-
#### 📝 Description
-
-
Deletes a fine-tuned model. The model will be removed from the system and will no longer be available for use.
This operation is irreversible.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.finetuning.delete_finetuned_model(
id="id",
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — The fine-tuned model ID.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.update_finetuned_model(...) -> UpdateFinetunedModelResponse
-
#### 📝 Description
-
-
Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
from cohere.finetuning.finetuning import Settings, BaseModel
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.finetuning.update_finetuned_model(
id="id",
name="name",
settings=Settings(
base_model=BaseModel(
base_type="BASE_TYPE_UNSPECIFIED",
),
dataset_id="dataset_id",
),
)
```
#### ⚙️ Parameters
-
-
**id:** `str` — FinetunedModel ID.
-
**name:** `str` — FinetunedModel name (e.g. `foobar`).
-
**settings:** `Settings` — FinetunedModel settings such as dataset, hyperparameters...
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.list_events(...) -> ListEventsResponse
-
#### 📝 Description
-
-
Returns a list of events that occurred during the life-cycle of the fine-tuned model.
The events are ordered by creation time, with the most recent event first.
The list can be paginated using `page_size` and `page_token` parameters.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.finetuning.list_events(
finetuned_model_id="finetuned_model_id",
page_size=1,
page_token="page_token",
order_by="order_by",
)
```
#### ⚙️ Parameters
-
-
**finetuned_model_id:** `str` — The parent fine-tuned model ID.
-
**page_size:** `typing.Optional[int]`
Maximum number of results to be returned by the server. If 0, defaults to
50.
-
**page_token:** `typing.Optional[str]` — Request a specific page of the list results.
-
**order_by:** `typing.Optional[str]`
Comma-separated list of fields. For example: "created_at,name". The default
sorting order is ascending. To specify descending order for a field, append
" desc" to the field name. For example: "created_at desc,name".
Supported sorting fields:
- created_at (default)
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
client.finetuning.list_training_step_metrics(...) -> ListTrainingStepMetricsResponse
-
#### 📝 Description
-
-
Returns a list of metrics measured during the training of a fine-tuned model.
The metrics are ordered by step number, with the most recent step first.
The list can be paginated using `page_size` and `page_token` parameters.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.finetuning.list_training_step_metrics(
finetuned_model_id="finetuned_model_id",
page_size=1,
page_token="page_token",
)
```
#### ⚙️ Parameters
-
-
**finetuned_model_id:** `str` — The parent fine-tuned model ID.
-
**page_size:** `typing.Optional[int]`
Maximum number of results to be returned by the server. If 0, defaults to
50.
-
**page_token:** `typing.Optional[str]` — Request a specific page of the list results.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
## Audio Transcriptions
client.audio.transcriptions.create(...) -> AudioTranscriptionsCreateResponse
-
#### 📝 Description
-
-
Transcribe an audio file.
#### 🔌 Usage
-
-
```python
from cohere import Client
from cohere.environment import ClientEnvironment
client = Client(
token="",
environment=ClientEnvironment.PRODUCTION,
)
client.audio.transcriptions.create(
file="example_file",
model="model",
language="language",
)
```
#### ⚙️ Parameters
-
-
**model:** `str` — ID of the model to use.
-
**language:** `str` — The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.
-
**file:** `core.File` — The audio file object to transcribe. Supported file extensions are flac, mp3, mpeg, mpga, ogg, and wav.
-
**temperature:** `typing.Optional[float]` — The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.
-
**request_options:** `typing.Optional[RequestOptions]` — Request-specific configuration.
================================================
FILE: requirements.txt
================================================
fastavro==1.9.4
httpx>=0.21.2
pydantic>=1.9.2
pydantic-core>=2.18.2,<2.44.0
requests==2.0.0
tokenizers>=0.15,<1
types-requests==2.0.0
typing_extensions>=4.0.0
================================================
FILE: src/cohere/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .types import (
ApiMeta,
ApiMetaApiVersion,
ApiMetaBilledUnits,
ApiMetaTokens,
AssistantChatMessageV2,
AssistantMessage,
AssistantMessageResponse,
AssistantMessageResponseContentItem,
AssistantMessageV2Content,
AssistantMessageV2ContentOneItem,
AuthTokenType,
ChatCitation,
ChatCitationGenerationEvent,
ChatCitationType,
ChatConnector,
ChatContentDeltaEvent,
ChatContentDeltaEventDelta,
ChatContentDeltaEventDeltaMessage,
ChatContentDeltaEventDeltaMessageContent,
ChatContentEndEvent,
ChatContentStartEvent,
ChatContentStartEventDelta,
ChatContentStartEventDeltaMessage,
ChatContentStartEventDeltaMessageContent,
ChatContentStartEventDeltaMessageContentType,
ChatDataMetrics,
ChatDebugEvent,
ChatDocument,
ChatDocumentSource,
ChatFinishReason,
ChatMessage,
ChatMessageEndEvent,
ChatMessageEndEventDelta,
ChatMessageStartEvent,
ChatMessageStartEventDelta,
ChatMessageStartEventDeltaMessage,
ChatMessageV2,
ChatMessages,
ChatRequestCitationQuality,
ChatRequestPromptTruncation,
ChatRequestSafetyMode,
ChatSearchQueriesGenerationEvent,
ChatSearchQuery,
ChatSearchResult,
ChatSearchResultConnector,
ChatSearchResultsEvent,
ChatStreamEndEvent,
ChatStreamEndEventFinishReason,
ChatStreamEvent,
ChatStreamEventType,
ChatStreamRequestCitationQuality,
ChatStreamRequestPromptTruncation,
ChatStreamRequestSafetyMode,
ChatStreamStartEvent,
ChatTextContent,
ChatTextGenerationEvent,
ChatTextResponseFormat,
ChatTextResponseFormatV2,
ChatThinkingContent,
ChatToolCallDeltaEvent,
ChatToolCallDeltaEventDelta,
ChatToolCallDeltaEventDeltaMessage,
ChatToolCallDeltaEventDeltaMessageToolCalls,
ChatToolCallDeltaEventDeltaMessageToolCallsFunction,
ChatToolCallEndEvent,
ChatToolCallStartEvent,
ChatToolCallStartEventDelta,
ChatToolCallStartEventDeltaMessage,
ChatToolCallsChunkEvent,
ChatToolCallsGenerationEvent,
ChatToolMessage,
ChatToolPlanDeltaEvent,
ChatToolPlanDeltaEventDelta,
ChatToolPlanDeltaEventDeltaMessage,
ChatToolSource,
ChatbotMessage,
CheckApiKeyResponse,
Citation,
CitationEndEvent,
CitationGenerationStreamedChatResponse,
CitationOptions,
CitationOptionsMode,
CitationStartEvent,
CitationStartEventDelta,
CitationStartEventDeltaMessage,
CitationType,
ClassifyDataMetrics,
ClassifyExample,
ClassifyRequestTruncate,
ClassifyResponse,
ClassifyResponseClassificationsItem,
ClassifyResponseClassificationsItemClassificationType,
ClassifyResponseClassificationsItemLabelsValue,
CompatibleEndpoint,
Connector,
ConnectorAuthStatus,
ConnectorOAuth,
Content,
CreateConnectorOAuth,
CreateConnectorResponse,
CreateConnectorServiceAuth,
CreateEmbedJobResponse,
Dataset,
DatasetPart,
DatasetType,
DatasetValidationStatus,
DebugStreamedChatResponse,
DeleteConnectorResponse,
DetokenizeResponse,
Document,
DocumentContent,
DocumentSource,
DocumentToolContent,
EmbedByTypeResponse,
EmbedByTypeResponseEmbeddings,
EmbedByTypeResponseResponseType,
EmbedContent,
EmbedFloatsResponse,
EmbedImage,
EmbedImageUrl,
EmbedInput,
EmbedInputType,
EmbedJob,
EmbedJobStatus,
EmbedJobTruncate,
EmbedRequestTruncate,
EmbedResponse,
EmbedText,
EmbeddingType,
EmbeddingsByTypeEmbedResponse,
EmbeddingsFloatsEmbedResponse,
FinetuneDatasetMetrics,
FinishReason,
GenerateRequestReturnLikelihoods,
GenerateRequestTruncate,
GenerateStreamEnd,
GenerateStreamEndResponse,
GenerateStreamError,
GenerateStreamEvent,
GenerateStreamRequestReturnLikelihoods,
GenerateStreamRequestTruncate,
GenerateStreamText,
GenerateStreamedResponse,
Generation,
GetConnectorResponse,
GetModelResponse,
GetModelResponseSamplingDefaults,
Image,
ImageContent,
ImageUrl,
ImageUrlContent,
ImageUrlDetail,
ImageUrlEmbedContent,
JsonObjectResponseFormat,
JsonObjectResponseFormatV2,
JsonResponseFormat,
JsonResponseFormatV2,
LabelMetric,
ListConnectorsResponse,
ListEmbedJobResponse,
ListModelsResponse,
LogprobItem,
Message,
Metrics,
NonStreamedChatResponse,
OAuthAuthorizeResponse,
ParseInfo,
RerankDocument,
RerankRequestDocumentsItem,
RerankResponse,
RerankResponseResultsItem,
RerankResponseResultsItemDocument,
RerankerDataMetrics,
ResponseFormat,
ResponseFormatV2,
SearchQueriesGenerationStreamedChatResponse,
SearchResultsStreamedChatResponse,
SingleGeneration,
SingleGenerationInStream,
SingleGenerationTokenLikelihoodsItem,
Source,
StreamEndGenerateStreamedResponse,
StreamEndStreamedChatResponse,
StreamErrorGenerateStreamedResponse,
StreamStartStreamedChatResponse,
StreamedChatResponse,
SummarizeRequestExtractiveness,
SummarizeRequestFormat,
SummarizeRequestLength,
SummarizeResponse,
SystemChatMessageV2,
SystemMessage,
SystemMessageV2,
SystemMessageV2Content,
SystemMessageV2ContentOneItem,
TextAssistantMessageResponseContentItem,
TextAssistantMessageV2ContentOneItem,
TextContent,
TextEmbedContent,
TextGenerationGenerateStreamedResponse,
TextGenerationStreamedChatResponse,
TextResponseFormat,
TextResponseFormatV2,
TextSystemMessageV2ContentOneItem,
TextToolContent,
Thinking,
ThinkingAssistantMessageResponseContentItem,
ThinkingAssistantMessageV2ContentOneItem,
ThinkingType,
TokenizeResponse,
Tool,
ToolCall,
ToolCallDelta,
ToolCallV2,
ToolCallV2Function,
ToolCallsChunkStreamedChatResponse,
ToolCallsGenerationStreamedChatResponse,
ToolChatMessageV2,
ToolContent,
ToolMessage,
ToolMessageV2,
ToolMessageV2Content,
ToolParameterDefinitionsValue,
ToolResult,
ToolSource,
ToolV2,
ToolV2Function,
UpdateConnectorResponse,
Usage,
UsageBilledUnits,
UsageTokens,
UserChatMessageV2,
UserMessage,
UserMessageV2,
UserMessageV2Content,
)
from .errors import (
BadRequestError,
ClientClosedRequestError,
ForbiddenError,
GatewayTimeoutError,
InternalServerError,
InvalidTokenError,
NotFoundError,
NotImplementedError,
ServiceUnavailableError,
TooManyRequestsError,
UnauthorizedError,
UnprocessableEntityError,
)
from . import audio, batches, connectors, datasets, embed_jobs, finetuning, models, v2
from ._default_clients import DefaultAioHttpClient, DefaultAsyncHttpxClient
from .aliases import (
ChatResponse,
ContentDeltaStreamedChatResponseV2,
ContentEndStreamedChatResponseV2,
ContentStartStreamedChatResponseV2,
MessageEndStreamedChatResponseV2,
MessageStartStreamedChatResponseV2,
StreamedChatResponseV2,
ToolCallDeltaStreamedChatResponseV2,
ToolCallEndStreamedChatResponseV2,
ToolCallStartStreamedChatResponseV2,
)
from .aws_client import AwsClient
from .batches import (
Batch,
BatchStatus,
CancelBatchResponse,
CreateBatchResponse,
GetBatchResponse,
ListBatchesResponse,
)
from .bedrock_client import BedrockClient, BedrockClientV2
from .client import AsyncClient, Client
from .client_v2 import AsyncClientV2, ClientV2
from .datasets import DatasetsCreateResponse, DatasetsGetResponse, DatasetsGetUsageResponse, DatasetsListResponse
from .embed_jobs import CreateEmbedJobRequestTruncate
from .environment import ClientEnvironment
from .oci_client import OciClient, OciClientV2
from .sagemaker_client import SagemakerClient, SagemakerClientV2
from .v2 import (
CitationEndV2ChatStreamResponse,
CitationStartV2ChatStreamResponse,
ContentDeltaV2ChatStreamResponse,
ContentEndV2ChatStreamResponse,
ContentStartV2ChatStreamResponse,
DebugV2ChatStreamResponse,
MessageEndV2ChatStreamResponse,
MessageStartV2ChatStreamResponse,
ToolCallDeltaV2ChatStreamResponse,
ToolCallEndV2ChatStreamResponse,
ToolCallStartV2ChatStreamResponse,
ToolPlanDeltaV2ChatStreamResponse,
V2ChatRequestDocumentsItem,
V2ChatRequestSafetyMode,
V2ChatRequestToolChoice,
V2ChatResponse,
V2ChatStreamRequestDocumentsItem,
V2ChatStreamRequestSafetyMode,
V2ChatStreamRequestToolChoice,
V2ChatStreamResponse,
V2EmbedRequestTruncate,
V2RerankResponse,
V2RerankResponseResultsItem,
)
from .version import __version__
# Lazy-import table consumed by ``__getattr__`` below (PEP 562): maps every
# public attribute of the ``cohere`` package to the relative module that
# defines it, so ``import cohere`` stays cheap and modules load on first use.
# Entries of the form {"x": ".x"} denote submodules rather than attributes.
_dynamic_imports: typing.Dict[str, str] = {
    "ApiMeta": ".types",
    "ApiMetaApiVersion": ".types",
    "ApiMetaBilledUnits": ".types",
    "ApiMetaTokens": ".types",
    "AssistantChatMessageV2": ".types",
    "AssistantMessage": ".types",
    "AssistantMessageResponse": ".types",
    "AssistantMessageResponseContentItem": ".types",
    "AssistantMessageV2Content": ".types",
    "AssistantMessageV2ContentOneItem": ".types",
    "AsyncClient": ".client",
    "AsyncClientV2": ".client_v2",
    "AuthTokenType": ".types",
    "AwsClient": ".aws_client",
    "BadRequestError": ".errors",
    "Batch": ".batches",
    "BatchStatus": ".batches",
    "BedrockClient": ".bedrock_client",
    "BedrockClientV2": ".bedrock_client",
    "CancelBatchResponse": ".batches",
    "ChatCitation": ".types",
    "ChatCitationGenerationEvent": ".types",
    "ChatCitationType": ".types",
    "ChatConnector": ".types",
    "ChatContentDeltaEvent": ".types",
    "ChatContentDeltaEventDelta": ".types",
    "ChatContentDeltaEventDeltaMessage": ".types",
    "ChatContentDeltaEventDeltaMessageContent": ".types",
    "ChatContentEndEvent": ".types",
    "ChatContentStartEvent": ".types",
    "ChatContentStartEventDelta": ".types",
    "ChatContentStartEventDeltaMessage": ".types",
    "ChatContentStartEventDeltaMessageContent": ".types",
    "ChatContentStartEventDeltaMessageContentType": ".types",
    "ChatDataMetrics": ".types",
    "ChatDebugEvent": ".types",
    "ChatDocument": ".types",
    "ChatDocumentSource": ".types",
    "ChatFinishReason": ".types",
    "ChatMessage": ".types",
    "ChatMessageEndEvent": ".types",
    "ChatMessageEndEventDelta": ".types",
    "ChatMessageStartEvent": ".types",
    "ChatMessageStartEventDelta": ".types",
    "ChatMessageStartEventDeltaMessage": ".types",
    "ChatMessageV2": ".types",
    "ChatMessages": ".types",
    "ChatRequestCitationQuality": ".types",
    "ChatRequestPromptTruncation": ".types",
    "ChatRequestSafetyMode": ".types",
    "ChatResponse": ".aliases",
    "ChatSearchQueriesGenerationEvent": ".types",
    "ChatSearchQuery": ".types",
    "ChatSearchResult": ".types",
    "ChatSearchResultConnector": ".types",
    "ChatSearchResultsEvent": ".types",
    "ChatStreamEndEvent": ".types",
    "ChatStreamEndEventFinishReason": ".types",
    "ChatStreamEvent": ".types",
    "ChatStreamEventType": ".types",
    "ChatStreamRequestCitationQuality": ".types",
    "ChatStreamRequestPromptTruncation": ".types",
    "ChatStreamRequestSafetyMode": ".types",
    "ChatStreamStartEvent": ".types",
    "ChatTextContent": ".types",
    "ChatTextGenerationEvent": ".types",
    "ChatTextResponseFormat": ".types",
    "ChatTextResponseFormatV2": ".types",
    "ChatThinkingContent": ".types",
    "ChatToolCallDeltaEvent": ".types",
    "ChatToolCallDeltaEventDelta": ".types",
    "ChatToolCallDeltaEventDeltaMessage": ".types",
    "ChatToolCallDeltaEventDeltaMessageToolCalls": ".types",
    "ChatToolCallDeltaEventDeltaMessageToolCallsFunction": ".types",
    "ChatToolCallEndEvent": ".types",
    "ChatToolCallStartEvent": ".types",
    "ChatToolCallStartEventDelta": ".types",
    "ChatToolCallStartEventDeltaMessage": ".types",
    "ChatToolCallsChunkEvent": ".types",
    "ChatToolCallsGenerationEvent": ".types",
    "ChatToolMessage": ".types",
    "ChatToolPlanDeltaEvent": ".types",
    "ChatToolPlanDeltaEventDelta": ".types",
    "ChatToolPlanDeltaEventDeltaMessage": ".types",
    "ChatToolSource": ".types",
    "ChatbotMessage": ".types",
    "CheckApiKeyResponse": ".types",
    "Citation": ".types",
    "CitationEndEvent": ".types",
    "CitationEndV2ChatStreamResponse": ".v2",
    "CitationGenerationStreamedChatResponse": ".types",
    "CitationOptions": ".types",
    "CitationOptionsMode": ".types",
    "CitationStartEvent": ".types",
    "CitationStartEventDelta": ".types",
    "CitationStartEventDeltaMessage": ".types",
    "CitationStartV2ChatStreamResponse": ".v2",
    "CitationType": ".types",
    "ClassifyDataMetrics": ".types",
    "ClassifyExample": ".types",
    "ClassifyRequestTruncate": ".types",
    "ClassifyResponse": ".types",
    "ClassifyResponseClassificationsItem": ".types",
    "ClassifyResponseClassificationsItemClassificationType": ".types",
    "ClassifyResponseClassificationsItemLabelsValue": ".types",
    "Client": ".client",
    "ClientClosedRequestError": ".errors",
    "ClientEnvironment": ".environment",
    "ClientV2": ".client_v2",
    "CompatibleEndpoint": ".types",
    "Connector": ".types",
    "ConnectorAuthStatus": ".types",
    "ConnectorOAuth": ".types",
    "Content": ".types",
    "ContentDeltaStreamedChatResponseV2": ".aliases",
    "ContentDeltaV2ChatStreamResponse": ".v2",
    "ContentEndStreamedChatResponseV2": ".aliases",
    "ContentEndV2ChatStreamResponse": ".v2",
    "ContentStartStreamedChatResponseV2": ".aliases",
    "ContentStartV2ChatStreamResponse": ".v2",
    "CreateBatchResponse": ".batches",
    "CreateConnectorOAuth": ".types",
    "CreateConnectorResponse": ".types",
    "CreateConnectorServiceAuth": ".types",
    "CreateEmbedJobRequestTruncate": ".embed_jobs",
    "CreateEmbedJobResponse": ".types",
    "Dataset": ".types",
    "DatasetPart": ".types",
    "DatasetType": ".types",
    "DatasetValidationStatus": ".types",
    "DatasetsCreateResponse": ".datasets",
    "DatasetsGetResponse": ".datasets",
    "DatasetsGetUsageResponse": ".datasets",
    "DatasetsListResponse": ".datasets",
    "DebugStreamedChatResponse": ".types",
    "DebugV2ChatStreamResponse": ".v2",
    "DefaultAioHttpClient": "._default_clients",
    "DefaultAsyncHttpxClient": "._default_clients",
    "DeleteConnectorResponse": ".types",
    "DetokenizeResponse": ".types",
    "Document": ".types",
    "DocumentContent": ".types",
    "DocumentSource": ".types",
    "DocumentToolContent": ".types",
    "EmbedByTypeResponse": ".types",
    "EmbedByTypeResponseEmbeddings": ".types",
    "EmbedByTypeResponseResponseType": ".types",
    "EmbedContent": ".types",
    "EmbedFloatsResponse": ".types",
    "EmbedImage": ".types",
    "EmbedImageUrl": ".types",
    "EmbedInput": ".types",
    "EmbedInputType": ".types",
    "EmbedJob": ".types",
    "EmbedJobStatus": ".types",
    "EmbedJobTruncate": ".types",
    "EmbedRequestTruncate": ".types",
    "EmbedResponse": ".types",
    "EmbedText": ".types",
    "EmbeddingType": ".types",
    "EmbeddingsByTypeEmbedResponse": ".types",
    "EmbeddingsFloatsEmbedResponse": ".types",
    "FinetuneDatasetMetrics": ".types",
    "FinishReason": ".types",
    "ForbiddenError": ".errors",
    "GatewayTimeoutError": ".errors",
    "GenerateRequestReturnLikelihoods": ".types",
    "GenerateRequestTruncate": ".types",
    "GenerateStreamEnd": ".types",
    "GenerateStreamEndResponse": ".types",
    "GenerateStreamError": ".types",
    "GenerateStreamEvent": ".types",
    "GenerateStreamRequestReturnLikelihoods": ".types",
    "GenerateStreamRequestTruncate": ".types",
    "GenerateStreamText": ".types",
    "GenerateStreamedResponse": ".types",
    "Generation": ".types",
    "GetBatchResponse": ".batches",
    "GetConnectorResponse": ".types",
    "GetModelResponse": ".types",
    "GetModelResponseSamplingDefaults": ".types",
    "Image": ".types",
    "ImageContent": ".types",
    "ImageUrl": ".types",
    "ImageUrlContent": ".types",
    "ImageUrlDetail": ".types",
    "ImageUrlEmbedContent": ".types",
    "InternalServerError": ".errors",
    "InvalidTokenError": ".errors",
    "JsonObjectResponseFormat": ".types",
    "JsonObjectResponseFormatV2": ".types",
    "JsonResponseFormat": ".types",
    "JsonResponseFormatV2": ".types",
    "LabelMetric": ".types",
    "ListBatchesResponse": ".batches",
    "ListConnectorsResponse": ".types",
    "ListEmbedJobResponse": ".types",
    "ListModelsResponse": ".types",
    "LogprobItem": ".types",
    "Message": ".types",
    "MessageEndStreamedChatResponseV2": ".aliases",
    "MessageEndV2ChatStreamResponse": ".v2",
    "MessageStartStreamedChatResponseV2": ".aliases",
    "MessageStartV2ChatStreamResponse": ".v2",
    "Metrics": ".types",
    "NonStreamedChatResponse": ".types",
    "NotFoundError": ".errors",
    "NotImplementedError": ".errors",
    "OAuthAuthorizeResponse": ".types",
    "OciClient": ".oci_client",
    "OciClientV2": ".oci_client",
    "ParseInfo": ".types",
    "RerankDocument": ".types",
    "RerankRequestDocumentsItem": ".types",
    "RerankResponse": ".types",
    "RerankResponseResultsItem": ".types",
    "RerankResponseResultsItemDocument": ".types",
    "RerankerDataMetrics": ".types",
    "ResponseFormat": ".types",
    "ResponseFormatV2": ".types",
    "SagemakerClient": ".sagemaker_client",
    "SagemakerClientV2": ".sagemaker_client",
    "SearchQueriesGenerationStreamedChatResponse": ".types",
    "SearchResultsStreamedChatResponse": ".types",
    "ServiceUnavailableError": ".errors",
    "SingleGeneration": ".types",
    "SingleGenerationInStream": ".types",
    "SingleGenerationTokenLikelihoodsItem": ".types",
    "Source": ".types",
    "StreamEndGenerateStreamedResponse": ".types",
    "StreamEndStreamedChatResponse": ".types",
    "StreamErrorGenerateStreamedResponse": ".types",
    "StreamStartStreamedChatResponse": ".types",
    "StreamedChatResponse": ".types",
    "StreamedChatResponseV2": ".aliases",
    "SummarizeRequestExtractiveness": ".types",
    "SummarizeRequestFormat": ".types",
    "SummarizeRequestLength": ".types",
    "SummarizeResponse": ".types",
    "SystemChatMessageV2": ".types",
    "SystemMessage": ".types",
    "SystemMessageV2": ".types",
    "SystemMessageV2Content": ".types",
    "SystemMessageV2ContentOneItem": ".types",
    "TextAssistantMessageResponseContentItem": ".types",
    "TextAssistantMessageV2ContentOneItem": ".types",
    "TextContent": ".types",
    "TextEmbedContent": ".types",
    "TextGenerationGenerateStreamedResponse": ".types",
    "TextGenerationStreamedChatResponse": ".types",
    "TextResponseFormat": ".types",
    "TextResponseFormatV2": ".types",
    "TextSystemMessageV2ContentOneItem": ".types",
    "TextToolContent": ".types",
    "Thinking": ".types",
    "ThinkingAssistantMessageResponseContentItem": ".types",
    "ThinkingAssistantMessageV2ContentOneItem": ".types",
    "ThinkingType": ".types",
    "TokenizeResponse": ".types",
    "TooManyRequestsError": ".errors",
    "Tool": ".types",
    "ToolCall": ".types",
    "ToolCallDelta": ".types",
    "ToolCallDeltaStreamedChatResponseV2": ".aliases",
    "ToolCallDeltaV2ChatStreamResponse": ".v2",
    "ToolCallEndStreamedChatResponseV2": ".aliases",
    "ToolCallEndV2ChatStreamResponse": ".v2",
    "ToolCallStartStreamedChatResponseV2": ".aliases",
    "ToolCallStartV2ChatStreamResponse": ".v2",
    "ToolCallV2": ".types",
    "ToolCallV2Function": ".types",
    "ToolCallsChunkStreamedChatResponse": ".types",
    "ToolCallsGenerationStreamedChatResponse": ".types",
    "ToolChatMessageV2": ".types",
    "ToolContent": ".types",
    "ToolMessage": ".types",
    "ToolMessageV2": ".types",
    "ToolMessageV2Content": ".types",
    "ToolParameterDefinitionsValue": ".types",
    "ToolPlanDeltaV2ChatStreamResponse": ".v2",
    "ToolResult": ".types",
    "ToolSource": ".types",
    "ToolV2": ".types",
    "ToolV2Function": ".types",
    "UnauthorizedError": ".errors",
    "UnprocessableEntityError": ".errors",
    "UpdateConnectorResponse": ".types",
    "Usage": ".types",
    "UsageBilledUnits": ".types",
    "UsageTokens": ".types",
    "UserChatMessageV2": ".types",
    "UserMessage": ".types",
    "UserMessageV2": ".types",
    "UserMessageV2Content": ".types",
    "V2ChatRequestDocumentsItem": ".v2",
    "V2ChatRequestSafetyMode": ".v2",
    "V2ChatRequestToolChoice": ".v2",
    "V2ChatResponse": ".v2",
    "V2ChatStreamRequestDocumentsItem": ".v2",
    "V2ChatStreamRequestSafetyMode": ".v2",
    "V2ChatStreamRequestToolChoice": ".v2",
    "V2ChatStreamResponse": ".v2",
    "V2EmbedRequestTruncate": ".v2",
    "V2RerankResponse": ".v2",
    "V2RerankResponseResultsItem": ".v2",
    "__version__": ".version",
    "audio": ".audio",
    "batches": ".batches",
    "connectors": ".connectors",
    "datasets": ".datasets",
    "embed_jobs": ".embed_jobs",
    "finetuning": ".finetuning",
    "models": ".models",
    "v2": ".v2",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve a package attribute lazily on first access (PEP 562).

    Looks up ``attr_name`` in ``_dynamic_imports``, imports the backing module
    relative to this package, and returns either the module itself (for
    submodule entries) or the named attribute pulled from that module.
    Unknown names raise ``AttributeError``; import/lookup failures are
    re-raised with the offending name and module included in the message.
    """
    target = _dynamic_imports.get(attr_name)
    if target is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(target, __package__)
        # An entry shaped like {"v2": ".v2"} means the attribute *is* the submodule.
        return resolved if target == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {target}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {target}: {e}") from e
def __dir__():
    """Advertise every lazily importable name to ``dir()`` and autocompletion."""
    return sorted(_dynamic_imports)
# Explicit public API of the package.  The names here match the keys of
# ``_dynamic_imports`` above so that ``from cohere import *``, ``dir()`` and
# static analyzers all see the same exported surface.
__all__ = [
    "ApiMeta",
    "ApiMetaApiVersion",
    "ApiMetaBilledUnits",
    "ApiMetaTokens",
    "AssistantChatMessageV2",
    "AssistantMessage",
    "AssistantMessageResponse",
    "AssistantMessageResponseContentItem",
    "AssistantMessageV2Content",
    "AssistantMessageV2ContentOneItem",
    "AsyncClient",
    "AsyncClientV2",
    "AuthTokenType",
    "AwsClient",
    "BadRequestError",
    "Batch",
    "BatchStatus",
    "BedrockClient",
    "BedrockClientV2",
    "CancelBatchResponse",
    "ChatCitation",
    "ChatCitationGenerationEvent",
    "ChatCitationType",
    "ChatConnector",
    "ChatContentDeltaEvent",
    "ChatContentDeltaEventDelta",
    "ChatContentDeltaEventDeltaMessage",
    "ChatContentDeltaEventDeltaMessageContent",
    "ChatContentEndEvent",
    "ChatContentStartEvent",
    "ChatContentStartEventDelta",
    "ChatContentStartEventDeltaMessage",
    "ChatContentStartEventDeltaMessageContent",
    "ChatContentStartEventDeltaMessageContentType",
    "ChatDataMetrics",
    "ChatDebugEvent",
    "ChatDocument",
    "ChatDocumentSource",
    "ChatFinishReason",
    "ChatMessage",
    "ChatMessageEndEvent",
    "ChatMessageEndEventDelta",
    "ChatMessageStartEvent",
    "ChatMessageStartEventDelta",
    "ChatMessageStartEventDeltaMessage",
    "ChatMessageV2",
    "ChatMessages",
    "ChatRequestCitationQuality",
    "ChatRequestPromptTruncation",
    "ChatRequestSafetyMode",
    "ChatResponse",
    "ChatSearchQueriesGenerationEvent",
    "ChatSearchQuery",
    "ChatSearchResult",
    "ChatSearchResultConnector",
    "ChatSearchResultsEvent",
    "ChatStreamEndEvent",
    "ChatStreamEndEventFinishReason",
    "ChatStreamEvent",
    "ChatStreamEventType",
    "ChatStreamRequestCitationQuality",
    "ChatStreamRequestPromptTruncation",
    "ChatStreamRequestSafetyMode",
    "ChatStreamStartEvent",
    "ChatTextContent",
    "ChatTextGenerationEvent",
    "ChatTextResponseFormat",
    "ChatTextResponseFormatV2",
    "ChatThinkingContent",
    "ChatToolCallDeltaEvent",
    "ChatToolCallDeltaEventDelta",
    "ChatToolCallDeltaEventDeltaMessage",
    "ChatToolCallDeltaEventDeltaMessageToolCalls",
    "ChatToolCallDeltaEventDeltaMessageToolCallsFunction",
    "ChatToolCallEndEvent",
    "ChatToolCallStartEvent",
    "ChatToolCallStartEventDelta",
    "ChatToolCallStartEventDeltaMessage",
    "ChatToolCallsChunkEvent",
    "ChatToolCallsGenerationEvent",
    "ChatToolMessage",
    "ChatToolPlanDeltaEvent",
    "ChatToolPlanDeltaEventDelta",
    "ChatToolPlanDeltaEventDeltaMessage",
    "ChatToolSource",
    "ChatbotMessage",
    "CheckApiKeyResponse",
    "Citation",
    "CitationEndEvent",
    "CitationEndV2ChatStreamResponse",
    "CitationGenerationStreamedChatResponse",
    "CitationOptions",
    "CitationOptionsMode",
    "CitationStartEvent",
    "CitationStartEventDelta",
    "CitationStartEventDeltaMessage",
    "CitationStartV2ChatStreamResponse",
    "CitationType",
    "ClassifyDataMetrics",
    "ClassifyExample",
    "ClassifyRequestTruncate",
    "ClassifyResponse",
    "ClassifyResponseClassificationsItem",
    "ClassifyResponseClassificationsItemClassificationType",
    "ClassifyResponseClassificationsItemLabelsValue",
    "Client",
    "ClientClosedRequestError",
    "ClientEnvironment",
    "ClientV2",
    "CompatibleEndpoint",
    "Connector",
    "ConnectorAuthStatus",
    "ConnectorOAuth",
    "Content",
    "ContentDeltaStreamedChatResponseV2",
    "ContentDeltaV2ChatStreamResponse",
    "ContentEndStreamedChatResponseV2",
    "ContentEndV2ChatStreamResponse",
    "ContentStartStreamedChatResponseV2",
    "ContentStartV2ChatStreamResponse",
    "CreateBatchResponse",
    "CreateConnectorOAuth",
    "CreateConnectorResponse",
    "CreateConnectorServiceAuth",
    "CreateEmbedJobRequestTruncate",
    "CreateEmbedJobResponse",
    "Dataset",
    "DatasetPart",
    "DatasetType",
    "DatasetValidationStatus",
    "DatasetsCreateResponse",
    "DatasetsGetResponse",
    "DatasetsGetUsageResponse",
    "DatasetsListResponse",
    "DebugStreamedChatResponse",
    "DebugV2ChatStreamResponse",
    "DefaultAioHttpClient",
    "DefaultAsyncHttpxClient",
    "DeleteConnectorResponse",
    "DetokenizeResponse",
    "Document",
    "DocumentContent",
    "DocumentSource",
    "DocumentToolContent",
    "EmbedByTypeResponse",
    "EmbedByTypeResponseEmbeddings",
    "EmbedByTypeResponseResponseType",
    "EmbedContent",
    "EmbedFloatsResponse",
    "EmbedImage",
    "EmbedImageUrl",
    "EmbedInput",
    "EmbedInputType",
    "EmbedJob",
    "EmbedJobStatus",
    "EmbedJobTruncate",
    "EmbedRequestTruncate",
    "EmbedResponse",
    "EmbedText",
    "EmbeddingType",
    "EmbeddingsByTypeEmbedResponse",
    "EmbeddingsFloatsEmbedResponse",
    "FinetuneDatasetMetrics",
    "FinishReason",
    "ForbiddenError",
    "GatewayTimeoutError",
    "GenerateRequestReturnLikelihoods",
    "GenerateRequestTruncate",
    "GenerateStreamEnd",
    "GenerateStreamEndResponse",
    "GenerateStreamError",
    "GenerateStreamEvent",
    "GenerateStreamRequestReturnLikelihoods",
    "GenerateStreamRequestTruncate",
    "GenerateStreamText",
    "GenerateStreamedResponse",
    "Generation",
    "GetBatchResponse",
    "GetConnectorResponse",
    "GetModelResponse",
    "GetModelResponseSamplingDefaults",
    "Image",
    "ImageContent",
    "ImageUrl",
    "ImageUrlContent",
    "ImageUrlDetail",
    "ImageUrlEmbedContent",
    "InternalServerError",
    "InvalidTokenError",
    "JsonObjectResponseFormat",
    "JsonObjectResponseFormatV2",
    "JsonResponseFormat",
    "JsonResponseFormatV2",
    "LabelMetric",
    "ListBatchesResponse",
    "ListConnectorsResponse",
    "ListEmbedJobResponse",
    "ListModelsResponse",
    "LogprobItem",
    "Message",
    "MessageEndStreamedChatResponseV2",
    "MessageEndV2ChatStreamResponse",
    "MessageStartStreamedChatResponseV2",
    "MessageStartV2ChatStreamResponse",
    "Metrics",
    "NonStreamedChatResponse",
    "NotFoundError",
    "NotImplementedError",
    "OAuthAuthorizeResponse",
    "OciClient",
    "OciClientV2",
    "ParseInfo",
    "RerankDocument",
    "RerankRequestDocumentsItem",
    "RerankResponse",
    "RerankResponseResultsItem",
    "RerankResponseResultsItemDocument",
    "RerankerDataMetrics",
    "ResponseFormat",
    "ResponseFormatV2",
    "SagemakerClient",
    "SagemakerClientV2",
    "SearchQueriesGenerationStreamedChatResponse",
    "SearchResultsStreamedChatResponse",
    "ServiceUnavailableError",
    "SingleGeneration",
    "SingleGenerationInStream",
    "SingleGenerationTokenLikelihoodsItem",
    "Source",
    "StreamEndGenerateStreamedResponse",
    "StreamEndStreamedChatResponse",
    "StreamErrorGenerateStreamedResponse",
    "StreamStartStreamedChatResponse",
    "StreamedChatResponse",
    "StreamedChatResponseV2",
    "SummarizeRequestExtractiveness",
    "SummarizeRequestFormat",
    "SummarizeRequestLength",
    "SummarizeResponse",
    "SystemChatMessageV2",
    "SystemMessage",
    "SystemMessageV2",
    "SystemMessageV2Content",
    "SystemMessageV2ContentOneItem",
    "TextAssistantMessageResponseContentItem",
    "TextAssistantMessageV2ContentOneItem",
    "TextContent",
    "TextEmbedContent",
    "TextGenerationGenerateStreamedResponse",
    "TextGenerationStreamedChatResponse",
    "TextResponseFormat",
    "TextResponseFormatV2",
    "TextSystemMessageV2ContentOneItem",
    "TextToolContent",
    "Thinking",
    "ThinkingAssistantMessageResponseContentItem",
    "ThinkingAssistantMessageV2ContentOneItem",
    "ThinkingType",
    "TokenizeResponse",
    "TooManyRequestsError",
    "Tool",
    "ToolCall",
    "ToolCallDelta",
    "ToolCallDeltaStreamedChatResponseV2",
    "ToolCallDeltaV2ChatStreamResponse",
    "ToolCallEndStreamedChatResponseV2",
    "ToolCallEndV2ChatStreamResponse",
    "ToolCallStartStreamedChatResponseV2",
    "ToolCallStartV2ChatStreamResponse",
    "ToolCallV2",
    "ToolCallV2Function",
    "ToolCallsChunkStreamedChatResponse",
    "ToolCallsGenerationStreamedChatResponse",
    "ToolChatMessageV2",
    "ToolContent",
    "ToolMessage",
    "ToolMessageV2",
    "ToolMessageV2Content",
    "ToolParameterDefinitionsValue",
    "ToolPlanDeltaV2ChatStreamResponse",
    "ToolResult",
    "ToolSource",
    "ToolV2",
    "ToolV2Function",
    "UnauthorizedError",
    "UnprocessableEntityError",
    "UpdateConnectorResponse",
    "Usage",
    "UsageBilledUnits",
    "UsageTokens",
    "UserChatMessageV2",
    "UserMessage",
    "UserMessageV2",
    "UserMessageV2Content",
    "V2ChatRequestDocumentsItem",
    "V2ChatRequestSafetyMode",
    "V2ChatRequestToolChoice",
    "V2ChatResponse",
    "V2ChatStreamRequestDocumentsItem",
    "V2ChatStreamRequestSafetyMode",
    "V2ChatStreamRequestToolChoice",
    "V2ChatStreamResponse",
    "V2EmbedRequestTruncate",
    "V2RerankResponse",
    "V2RerankResponseResultsItem",
    "__version__",
    "audio",
    "batches",
    "connectors",
    "datasets",
    "embed_jobs",
    "finetuning",
    "models",
    "v2",
]
================================================
FILE: src/cohere/_default_clients.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import httpx
# Default request timeout (seconds) applied to SDK-constructed HTTP clients.
SDK_DEFAULT_TIMEOUT = 60

# The aiohttp transport is an optional extra. When httpx_aiohttp is missing we
# still define a class with the same name so that importing this module never
# fails; constructing it raises with an actionable install hint instead.
try:
    import httpx_aiohttp  # type: ignore[import-not-found]
except ImportError:

    class DefaultAioHttpClient(httpx.AsyncClient):  # type: ignore
        # Placeholder: the aiohttp extra is not installed.
        def __init__(self, **kwargs: typing.Any) -> None:
            raise RuntimeError("To use the aiohttp client, install the aiohttp extra: pip install cohere[aiohttp]")

else:

    class DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient):  # type: ignore
        # aiohttp-backed async client with the SDK's default timeout and
        # redirect policy; caller-supplied kwargs take precedence.
        def __init__(self, **kwargs: typing.Any) -> None:
            kwargs.setdefault("timeout", SDK_DEFAULT_TIMEOUT)
            kwargs.setdefault("follow_redirects", True)
            super().__init__(**kwargs)
class DefaultAsyncHttpxClient(httpx.AsyncClient):
    """httpx.AsyncClient preconfigured with the SDK's default timeout and redirect policy.

    Caller-supplied keyword arguments always take precedence over the defaults.
    """

    def __init__(self, **kwargs: typing.Any) -> None:
        sdk_defaults: typing.Dict[str, typing.Any] = {
            "timeout": SDK_DEFAULT_TIMEOUT,
            "follow_redirects": True,
        }
        for option, value in sdk_defaults.items():
            kwargs.setdefault(option, value)
        super().__init__(**kwargs)
================================================
FILE: src/cohere/aliases.py
================================================
# Import overrides early to ensure they're applied before types are used
# This is necessary for backwards compatibility patches like ToolCallV2.id being optional
from . import overrides # noqa: F401
from .v2 import (
ContentDeltaV2ChatStreamResponse,
ContentEndV2ChatStreamResponse,
ContentStartV2ChatStreamResponse,
MessageEndV2ChatStreamResponse,
MessageStartV2ChatStreamResponse,
ToolCallDeltaV2ChatStreamResponse,
ToolCallEndV2ChatStreamResponse,
ToolCallStartV2ChatStreamResponse,
V2ChatStreamResponse,
V2ChatResponse
)
# alias classes
# Backwards-compatible aliases: these map the generated "*V2ChatStreamResponse"
# names onto the "*StreamedChatResponseV2" spellings so code importing either
# name resolves to the same classes.
StreamedChatResponseV2 = V2ChatStreamResponse
MessageStartStreamedChatResponseV2 = MessageStartV2ChatStreamResponse
MessageEndStreamedChatResponseV2 = MessageEndV2ChatStreamResponse
ContentStartStreamedChatResponseV2 = ContentStartV2ChatStreamResponse
ContentDeltaStreamedChatResponseV2 = ContentDeltaV2ChatStreamResponse
ContentEndStreamedChatResponseV2 = ContentEndV2ChatStreamResponse
ToolCallStartStreamedChatResponseV2 = ToolCallStartV2ChatStreamResponse
ToolCallDeltaStreamedChatResponseV2 = ToolCallDeltaV2ChatStreamResponse
ToolCallEndStreamedChatResponseV2 = ToolCallEndV2ChatStreamResponse
# Non-streaming chat response alias.
ChatResponse = V2ChatResponse
================================================
FILE: src/cohere/audio/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from . import transcriptions
from .transcriptions import AudioTranscriptionsCreateResponse
# Maps a public attribute name to the relative module that provides it; an
# entry of the form "name" -> ".name" marks the attribute as a submodule.
_dynamic_imports: typing.Dict[str, str] = {
    "AudioTranscriptionsCreateResponse": ".transcriptions",
    "transcriptions": ".transcriptions",
}


def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve *attr_name* through the _dynamic_imports table."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(module_name, __package__)
        # When the attribute names the submodule itself, hand back the module;
        # otherwise pull the attribute out of the module.
        return resolved if module_name == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Advertise the lazily importable attributes for introspection."""
    return sorted(_dynamic_imports)


__all__ = ["AudioTranscriptionsCreateResponse", "transcriptions"]
================================================
FILE: src/cohere/audio/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from .raw_client import AsyncRawAudioClient, RawAudioClient
if typing.TYPE_CHECKING:
from .transcriptions.client import AsyncTranscriptionsClient, TranscriptionsClient
class AudioClient:
    """Synchronous entry point for the audio API; sub-clients are built lazily."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawAudioClient(client_wrapper=client_wrapper)
        self._client_wrapper = client_wrapper
        self._transcriptions: typing.Optional[TranscriptionsClient] = None

    @property
    def with_raw_response(self) -> RawAudioClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawAudioClient
        """
        return self._raw_client

    @property
    def transcriptions(self):
        # Guard clause: reuse the cached sub-client once it exists.
        if self._transcriptions is not None:
            return self._transcriptions
        # Imported lazily so the submodule is only loaded on first use.
        from .transcriptions.client import TranscriptionsClient  # noqa: E402

        self._transcriptions = TranscriptionsClient(client_wrapper=self._client_wrapper)
        return self._transcriptions
class AsyncAudioClient:
    """Asynchronous entry point for the audio API; sub-clients are built lazily."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawAudioClient(client_wrapper=client_wrapper)
        self._client_wrapper = client_wrapper
        self._transcriptions: typing.Optional[AsyncTranscriptionsClient] = None

    @property
    def with_raw_response(self) -> AsyncRawAudioClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawAudioClient
        """
        return self._raw_client

    @property
    def transcriptions(self):
        # Guard clause: reuse the cached sub-client once it exists.
        if self._transcriptions is not None:
            return self._transcriptions
        # Imported lazily so the submodule is only loaded on first use.
        from .transcriptions.client import AsyncTranscriptionsClient  # noqa: E402

        self._transcriptions = AsyncTranscriptionsClient(client_wrapper=self._client_wrapper)
        return self._transcriptions
================================================
FILE: src/cohere/audio/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
class RawAudioClient:
    """Raw counterpart of AudioClient; currently only stores the shared
    client wrapper (the audio endpoints live on sub-clients)."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper
class AsyncRawAudioClient:
    """Raw counterpart of AsyncAudioClient; currently only stores the shared
    client wrapper (the audio endpoints live on sub-clients)."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._client_wrapper = client_wrapper
================================================
FILE: src/cohere/audio/transcriptions/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .types import AudioTranscriptionsCreateResponse
# Maps a public attribute name to the relative module that provides it.
_dynamic_imports: typing.Dict[str, str] = {"AudioTranscriptionsCreateResponse": ".types"}


def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve *attr_name* through the _dynamic_imports table."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(module_name, __package__)
        # When the attribute names the submodule itself, hand back the module;
        # otherwise pull the attribute out of the module.
        return resolved if module_name == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Advertise the lazily importable attributes for introspection."""
    return sorted(_dynamic_imports)


__all__ = ["AudioTranscriptionsCreateResponse"]
================================================
FILE: src/cohere/audio/transcriptions/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ... import core
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ...core.request_options import RequestOptions
from .raw_client import AsyncRawTranscriptionsClient, RawTranscriptionsClient
from .types.audio_transcriptions_create_response import AudioTranscriptionsCreateResponse
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)


class TranscriptionsClient:
    """Synchronous client for the audio transcription endpoint."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawTranscriptionsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawTranscriptionsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawTranscriptionsClient
        """
        return self._raw_client

    def create(
        self,
        *,
        model: str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AudioTranscriptionsCreateResponse:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            The audio file to transcribe. See core.File for more documentation.

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make
            the output more random, while lower values like 0.2 make it more
            focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AudioTranscriptionsCreateResponse
            A successful response.
        """
        raw_response = self._raw_client.create(
            model=model,
            language=language,
            file=file,
            temperature=temperature,
            request_options=request_options,
        )
        return raw_response.data
class AsyncTranscriptionsClient:
    """Asynchronous client for the audio transcription endpoint."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawTranscriptionsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawTranscriptionsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawTranscriptionsClient
        """
        return self._raw_client

    async def create(
        self,
        *,
        model: str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AudioTranscriptionsCreateResponse:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            The audio file to transcribe. See core.File for more documentation.

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make
            the output more random, while lower values like 0.2 make it more
            focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AudioTranscriptionsCreateResponse
            A successful response.
        """
        raw_response = await self._raw_client.create(
            model=model,
            language=language,
            file=file,
            temperature=temperature,
            request_options=request_options,
        )
        return raw_response.data
================================================
FILE: src/cohere/audio/transcriptions/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from json.decoder import JSONDecodeError
from ... import core
from ...core.api_error import ApiError
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ...core.http_response import AsyncHttpResponse, HttpResponse
from ...core.parse_error import ParsingError
from ...core.request_options import RequestOptions
from ...core.unchecked_base_model import construct_type
from ...errors.bad_request_error import BadRequestError
from ...errors.client_closed_request_error import ClientClosedRequestError
from ...errors.forbidden_error import ForbiddenError
from ...errors.gateway_timeout_error import GatewayTimeoutError
from ...errors.internal_server_error import InternalServerError
from ...errors.invalid_token_error import InvalidTokenError
from ...errors.not_found_error import NotFoundError
from ...errors.not_implemented_error import NotImplementedError
from ...errors.service_unavailable_error import ServiceUnavailableError
from ...errors.too_many_requests_error import TooManyRequestsError
from ...errors.unauthorized_error import UnauthorizedError
from ...errors.unprocessable_entity_error import UnprocessableEntityError
from .types.audio_transcriptions_create_response import AudioTranscriptionsCreateResponse
from pydantic import ValidationError
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)


class RawTranscriptionsClient:
    """Synchronous low-level transcriptions client returning HttpResponse wrappers."""

    # Every documented non-2xx status code mapped to the typed error raised for
    # it. All twelve error branches previously built byte-identical payloads
    # (headers dict + permissively-parsed JSON body), so the dispatch is
    # data-driven instead of a twelve-way if-chain.
    _STATUS_TO_ERROR: typing.ClassVar[typing.Dict[int, typing.Any]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    def create(
        self,
        *,
        model: str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[AudioTranscriptionsCreateResponse]:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            See core.File for more documentation

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[AudioTranscriptionsCreateResponse]
            A successful response.

        Raises
        ------
        ApiError
            For undocumented status codes, or when the error body is not JSON.
        ParsingError
            When a response body fails model validation.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v2/audio/transcriptions",
            method="POST",
            data={
                "model": model,
                "language": language,
                "temperature": temperature,
            },
            files={
                "file": file,
            },
            request_options=request_options,
            omit=OMIT,
            force_multipart=True,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    AudioTranscriptionsCreateResponse,
                    construct_type(
                        type_=AudioTranscriptionsCreateResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            _error_cls = self._STATUS_TO_ERROR.get(_response.status_code)
            if _error_cls is not None:
                raise _error_cls(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            # Undocumented status code: fall through to a generic ApiError with
            # the parsed JSON body.
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
class AsyncRawTranscriptionsClient:
    """Asynchronous low-level transcriptions client returning AsyncHttpResponse wrappers."""

    # Every documented non-2xx status code mapped to the typed error raised for
    # it. All twelve error branches previously built byte-identical payloads
    # (headers dict + permissively-parsed JSON body), so the dispatch is
    # data-driven instead of a twelve-way if-chain.
    _STATUS_TO_ERROR: typing.ClassVar[typing.Dict[int, typing.Any]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._client_wrapper = client_wrapper

    async def create(
        self,
        *,
        model: str,
        language: str,
        file: core.File,
        temperature: typing.Optional[float] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[AudioTranscriptionsCreateResponse]:
        """
        Transcribe an audio file.

        Parameters
        ----------
        model : str
            ID of the model to use.

        language : str
            The language of the input audio, supplied in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format.

        file : core.File
            See core.File for more documentation

        temperature : typing.Optional[float]
            The sampling temperature, between 0 and 1. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[AudioTranscriptionsCreateResponse]
            A successful response.

        Raises
        ------
        ApiError
            For undocumented status codes, or when the error body is not JSON.
        ParsingError
            When a response body fails model validation.
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v2/audio/transcriptions",
            method="POST",
            data={
                "model": model,
                "language": language,
                "temperature": temperature,
            },
            files={
                "file": file,
            },
            request_options=request_options,
            omit=OMIT,
            force_multipart=True,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    AudioTranscriptionsCreateResponse,
                    construct_type(
                        type_=AudioTranscriptionsCreateResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=_data)
            _error_cls = self._STATUS_TO_ERROR.get(_response.status_code)
            if _error_cls is not None:
                raise _error_cls(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            # Undocumented status code: fall through to a generic ApiError with
            # the parsed JSON body.
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/audio/transcriptions/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .audio_transcriptions_create_response import AudioTranscriptionsCreateResponse
# Maps a public attribute name to the relative module that provides it.
_dynamic_imports: typing.Dict[str, str] = {"AudioTranscriptionsCreateResponse": ".audio_transcriptions_create_response"}


def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve *attr_name* through the _dynamic_imports table."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(module_name, __package__)
        # When the attribute names the submodule itself, hand back the module;
        # otherwise pull the attribute out of the module.
        return resolved if module_name == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Advertise the lazily importable attributes for introspection."""
    return sorted(_dynamic_imports)


__all__ = ["AudioTranscriptionsCreateResponse"]
================================================
FILE: src/cohere/audio/transcriptions/types/audio_transcriptions_create_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
class AudioTranscriptionsCreateResponse(UncheckedBaseModel):
    """Response body returned by the audio transcription endpoint."""

    text: str = pydantic.Field()
    """
    The transcribed text.
    """

    # Allow fields not declared on the model so newer server responses do not
    # fail validation; branch handles the pydantic v1/v2 config split.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/aws_client.py
================================================
import base64
import json
import re
import typing
import httpx
from httpx import URL, SyncByteStream, ByteStream
from . import GenerateStreamedResponse, Generation, \
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
ApiMetaBilledUnits
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
from .client_v2 import ClientV2
class AwsClient(Client):
    """v1 Cohere client routed through AWS Bedrock or SageMaker.

    Installs httpx event hooks that rewrite each outgoing request into a
    SigV4-signed AWS invocation and translate the AWS response back into the
    Cohere wire format. The base_url and api_key passed to the parent are
    placeholders — the request hook replaces the URL entirely.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        Client.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=httpx.Client(
                event_hooks=get_event_hooks(
                    service=service,
                    aws_access_key=aws_access_key,
                    aws_secret_key=aws_secret_key,
                    aws_session_token=aws_session_token,
                    aws_region=aws_region,
                ),
                timeout=timeout,
            ),
        )
class AwsClientV2(ClientV2):
    """v2 Cohere client routed through AWS Bedrock or SageMaker.

    Same hook-based request/response rewriting as AwsClient, built on the v2
    client. The base_url and api_key passed to the parent are placeholders —
    the request hook replaces the URL entirely.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        ClientV2.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=httpx.Client(
                event_hooks=get_event_hooks(
                    service=service,
                    aws_access_key=aws_access_key,
                    aws_secret_key=aws_secret_key,
                    aws_session_token=aws_session_token,
                    aws_region=aws_region,
                ),
                timeout=timeout,
            ),
        )
EventHook = typing.Callable[..., typing.Any]


def get_event_hooks(
    service: str,
    aws_access_key: typing.Optional[str] = None,
    aws_secret_key: typing.Optional[str] = None,
    aws_session_token: typing.Optional[str] = None,
    aws_region: typing.Optional[str] = None,
) -> typing.Dict[str, typing.List[EventHook]]:
    """Assemble the httpx request/response hooks that adapt calls to AWS.

    The request hook (and with it the boto3 session/signer) is created here,
    once, rather than per request.
    """
    request_hook = map_request_to_bedrock(
        service=service,
        aws_access_key=aws_access_key,
        aws_secret_key=aws_secret_key,
        aws_session_token=aws_session_token,
        aws_region=aws_region,
    )
    response_hook = map_response_from_bedrock()
    return {"request": [request_hook], "response": [response_hook]}
# Shapes of the JSON events decoded from the base64 payloads of Bedrock's
# event stream (see stream_generator below).
TextGeneration = typing.TypedDict('TextGeneration',
                                  {"text": str, "is_finished": str, "event_type": typing.Literal["text-generation"]})
StreamEnd = typing.TypedDict('StreamEnd',
                             {"is_finished": str, "event_type": typing.Literal["stream-end"], "finish_reason": str,
                              # "amazon-bedrock-invocationMetrics": {
                              #     "inputTokenCount": int, "outputTokenCount": int, "invocationLatency": int,
                              #     "firstByteLatency": int}
                              })
class Streamer(SyncByteStream):
    """Adapts an iterator of byte chunks to httpx's SyncByteStream interface."""

    lines: typing.Iterator[bytes]

    def __init__(self, lines: typing.Iterator[bytes]):
        self.lines = lines

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self.lines
# Endpoint name -> SDK response type used to re-parse non-streaming AWS
# payloads into Cohere response objects (see map_response_from_bedrock).
response_mapping: typing.Dict[str, typing.Any] = {
    "chat": NonStreamedChatResponse,
    "embed": EmbedResponse,
    "generate": Generation,
    "rerank": RerankResponse
}
# Endpoint name -> SDK streamed-response type for event-stream payloads
# (see stream_generator).
stream_response_mapping: typing.Dict[str, typing.Any] = {
    "chat": StreamedChatResponse,
    "generate": GenerateStreamedResponse,
}
def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator[bytes]:
    """Translate an AWS event-stream response into newline-delimited JSON bytes.

    Each upstream line is scanned for a brace-delimited JSON object; AWS wraps
    the actual payload base64-encoded under a "bytes" key. Decoded payloads
    carrying an "event_type" are parsed into the SDK streamed-response type for
    *endpoint* and re-emitted as one JSON document per line.
    """
    # NOTE(review): this pattern stops at the first "}", so it assumes the
    # framing line contains no nested JSON objects — confirm against the AWS
    # event-stream format.
    regex = r"{[^\}]*}"
    for _text in response.iter_lines():
        match = re.search(regex, _text)
        if match:
            obj = json.loads(match.group())
            if "bytes" in obj:
                # The payload itself is base64-encoded JSON.
                base64_payload = base64.b64decode(obj["bytes"]).decode("utf-8")
                streamed_obj = json.loads(base64_payload)
                if "event_type" in streamed_obj:
                    response_type = stream_response_mapping[endpoint]
                    parsed = typing.cast(response_type,  # type: ignore
                                         construct_type(type_=response_type, object_=streamed_obj))
                    yield (json.dumps(parsed.dict()) + "\n").encode("utf-8")  # type: ignore
def map_token_counts(response: httpx.Response) -> ApiMeta:
    """Build ApiMeta token accounting from Bedrock's token-count response headers.

    Missing headers yield -1 for the corresponding count.
    """
    headers = response.headers
    input_tokens = int(headers.get("X-Amzn-Bedrock-Input-Token-Count", -1))
    output_tokens = int(headers.get("X-Amzn-Bedrock-Output-Token-Count", -1))
    token_counts = ApiMetaTokens(input_tokens=input_tokens, output_tokens=output_tokens)
    billed = ApiMetaBilledUnits(input_tokens=input_tokens, output_tokens=output_tokens)
    return ApiMeta(tokens=token_counts, billed_units=billed)
def map_response_from_bedrock():
    """Build the httpx response hook that rewrites Bedrock/SageMaker payloads
    into the Cohere SDK wire format before the client parses them."""

    def _hook(
        response: httpx.Response,
    ) -> None:
        # AWS streaming responses use the event-stream content type; anything
        # else is treated as a plain JSON body.
        stream = response.headers["content-type"] == "application/vnd.amazon.eventstream"
        # The request hook recorded which Cohere endpoint this call targeted.
        endpoint = response.request.extensions["endpoint"]
        output: typing.Iterator[bytes]
        if stream:
            output = stream_generator(httpx.Response(
                stream=response.stream,
                status_code=response.status_code,
            ), endpoint)
        else:
            response_type = response_mapping[endpoint]
            response_obj = json.loads(response.read())
            # Token counts only exist as response headers on AWS, so they are
            # spliced into the body's "meta" field here.
            response_obj["meta"] = map_token_counts(response).dict()
            cast_obj: typing.Any = typing.cast(response_type,  # type: ignore
                                               construct_type(
                                                   type_=response_type,
                                                   # type: ignore
                                                   object_=response_obj))
            output = iter([json.dumps(cast_obj.dict()).encode("utf-8")])
        # Replace the body with the rewritten byte stream.
        response.stream = Streamer(output)
        # reset response object to allow for re-reading
        if hasattr(response, "_content"):
            del response._content
        response.is_stream_consumed = False
        response.is_closed = False

    return _hook
def get_boto3_session(
    **kwargs: typing.Any,
):
    """Create a boto3 Session, dropping any keyword arguments that are None
    so boto3 falls back to its own credential/region resolution."""
    provided = {name: value for name, value in kwargs.items() if value is not None}
    return lazy_boto3().Session(**provided)
def map_request_to_bedrock(
    service: str,
    aws_access_key: typing.Optional[str] = None,
    aws_secret_key: typing.Optional[str] = None,
    aws_session_token: typing.Optional[str] = None,
    aws_region: typing.Optional[str] = None,
) -> EventHook:
    """Create an httpx request hook that rewrites Cohere API calls into
    SigV4-signed Bedrock/SageMaker invocations.

    The boto3 session and SigV4 signer are built once at hook-creation time;
    only the per-request rewrite runs inside the hook.
    """
    session = get_boto3_session(
        region_name=aws_region,
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
        aws_session_token=aws_session_token,
    )
    # Fall back to the session's resolved region when none was supplied.
    aws_region = session.region_name
    credentials = session.get_credentials()
    signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)

    def _event_hook(request: httpx.Request) -> None:
        headers = request.headers.copy()
        # NOTE(review): raises KeyError if "connection" is absent; presumably
        # httpx always sets it on these requests — confirm.
        del headers["connection"]
        # The Cohere path looks like /<api_version>/<endpoint>; both segments
        # are recovered here before the URL is replaced.
        api_version = request.url.path.split("/")[-2]
        endpoint = request.url.path.split("/")[-1]
        body = json.loads(request.read())
        model = body["model"]
        url = get_url(
            platform=service,
            aws_region=aws_region,
            model=model,  # type: ignore
            stream="stream" in body and body["stream"],
        )
        request.url = URL(url)
        request.headers["host"] = request.url.host
        headers["host"] = request.url.host
        if endpoint == "rerank":
            # The AWS rerank body carries the numeric API version explicitly.
            body["api_version"] = get_api_version(version=api_version)
        # "stream" and "model" are expressed in the AWS URL, not the body.
        if "stream" in body:
            del body["stream"]
        if "model" in body:
            del body["model"]
        new_body = json.dumps(body).encode("utf-8")
        request.stream = ByteStream(new_body)
        request._content = new_body
        headers["content-length"] = str(len(new_body))
        # Sign the rewritten request and adopt the SigV4-authorized headers.
        aws_request = lazy_botocore().awsrequest.AWSRequest(
            method=request.method,
            url=url,
            headers=headers,
            data=request.read(),
        )
        signer.add_auth(aws_request)
        request.headers = httpx.Headers(aws_request.prepare().headers)
        # Remember the endpoint so the response hook knows how to parse it.
        request.extensions["endpoint"] = endpoint

    return _event_hook
def get_url(
    *,
    platform: str,
    aws_region: typing.Optional[str],
    model: str,
    stream: bool,
) -> str:
    """Build the AWS invocation URL for *model* on *platform*.

    Returns an empty string when the platform is neither "bedrock" nor
    "sagemaker".
    """
    if platform == "bedrock":
        endpoint = "invoke-with-response-stream" if stream else "invoke"
        return f"https://{platform}-runtime.{aws_region}.amazonaws.com/model/{model}/{endpoint}"
    if platform == "sagemaker":
        endpoint = "invocations-response-stream" if stream else "invocations"
        return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
    return ""
def get_api_version(*, version: str):
    """Map a URL version segment ("v1"/"v2") to its integer form.

    Unknown segments default to 1.
    """
    return {"v1": 1, "v2": 2}.get(version, 1)
================================================
FILE: src/cohere/base_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import os
import typing
import httpx
from .core.api_error import ApiError
from .core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from .core.logging import LogConfig, Logger
from .core.request_options import RequestOptions
from .environment import ClientEnvironment
from .raw_base_client import AsyncRawBaseCohere, RawBaseCohere
from .types.chat_connector import ChatConnector
from .types.chat_document import ChatDocument
from .types.chat_request_citation_quality import ChatRequestCitationQuality
from .types.chat_request_prompt_truncation import ChatRequestPromptTruncation
from .types.chat_request_safety_mode import ChatRequestSafetyMode
from .types.chat_stream_request_citation_quality import ChatStreamRequestCitationQuality
from .types.chat_stream_request_prompt_truncation import ChatStreamRequestPromptTruncation
from .types.chat_stream_request_safety_mode import ChatStreamRequestSafetyMode
from .types.check_api_key_response import CheckApiKeyResponse
from .types.classify_example import ClassifyExample
from .types.classify_request_truncate import ClassifyRequestTruncate
from .types.classify_response import ClassifyResponse
from .types.detokenize_response import DetokenizeResponse
from .types.embed_input_type import EmbedInputType
from .types.embed_request_truncate import EmbedRequestTruncate
from .types.embed_response import EmbedResponse
from .types.embedding_type import EmbeddingType
from .types.generate_request_return_likelihoods import GenerateRequestReturnLikelihoods
from .types.generate_request_truncate import GenerateRequestTruncate
from .types.generate_stream_request_return_likelihoods import GenerateStreamRequestReturnLikelihoods
from .types.generate_stream_request_truncate import GenerateStreamRequestTruncate
from .types.generate_streamed_response import GenerateStreamedResponse
from .types.generation import Generation
from .types.message import Message
from .types.non_streamed_chat_response import NonStreamedChatResponse
from .types.rerank_request_documents_item import RerankRequestDocumentsItem
from .types.rerank_response import RerankResponse
from .types.response_format import ResponseFormat
from .types.streamed_chat_response import StreamedChatResponse
from .types.summarize_request_extractiveness import SummarizeRequestExtractiveness
from .types.summarize_request_format import SummarizeRequestFormat
from .types.summarize_request_length import SummarizeRequestLength
from .types.summarize_response import SummarizeResponse
from .types.tokenize_response import TokenizeResponse
from .types.tool import Tool
from .types.tool_result import ToolResult
if typing.TYPE_CHECKING:
from .audio.client import AsyncAudioClient, AudioClient
from .batches.client import AsyncBatchesClient, BatchesClient
from .connectors.client import AsyncConnectorsClient, ConnectorsClient
from .datasets.client import AsyncDatasetsClient, DatasetsClient
from .embed_jobs.client import AsyncEmbedJobsClient, EmbedJobsClient
from .finetuning.client import AsyncFinetuningClient, FinetuningClient
from .models.client import AsyncModelsClient, ModelsClient
from .v2.client import AsyncV2Client, V2Client
# this is used as the default value for optional parameters: the Ellipsis
# sentinel lets the SDK distinguish "argument omitted" from an explicit None.
OMIT = typing.cast(typing.Any, ...)
class BaseCohere:
"""
Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propagate to these functions.
Parameters
----------
base_url : typing.Optional[str]
The base url to use for requests from the client.
environment : ClientEnvironment
The environment to use for requests from the client. from .environment import ClientEnvironment
Defaults to ClientEnvironment.PRODUCTION
client_name : typing.Optional[str]
token : typing.Optional[typing.Union[str, typing.Callable[[], str]]]
headers : typing.Optional[typing.Dict[str, str]]
Additional headers to send with every request.
timeout : typing.Optional[float]
The timeout to be used, in seconds, for requests. By default the timeout is 300 seconds, unless a custom httpx client is used, in which case this default is not enforced.
follow_redirects : typing.Optional[bool]
Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in.
httpx_client : typing.Optional[httpx.Client]
The httpx client to use for making requests, a preconfigured client is used by default, however this is useful should you want to pass in any custom httpx configuration.
logging : typing.Optional[typing.Union[LogConfig, Logger]]
Configure logging for the SDK. Accepts a LogConfig dict with 'level' (debug/info/warn/error), 'logger' (custom logger implementation), and 'silent' (boolean, defaults to True) fields. You can also pass a pre-configured Logger instance.
Examples
--------
from cohere import Client
client = Client(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)
"""
def __init__(
    self,
    *,
    base_url: typing.Optional[str] = None,
    environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
    client_name: typing.Optional[str] = None,
    token: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = os.getenv("CO_API_KEY"),
    headers: typing.Optional[typing.Dict[str, str]] = None,
    timeout: typing.Optional[float] = None,
    follow_redirects: typing.Optional[bool] = True,
    httpx_client: typing.Optional[httpx.Client] = None,
    logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
):
    """
    Initialize the client and its underlying HTTP transport.

    Raises
    ------
    ApiError
        If no token is supplied and the ``CO_API_KEY`` environment variable
        was unset when this module was imported (the env var is read once,
        as the parameter default, not on every call).
    """
    # Prefer an explicit timeout; otherwise default to 300s, unless a custom
    # httpx client was provided, in which case honor its read timeout.
    _defaulted_timeout = (
        timeout if timeout is not None else 300 if httpx_client is None else httpx_client.timeout.read
    )
    if token is None:
        raise ApiError(body="The client must be instantiated by either passing in token or setting CO_API_KEY")
    self._client_wrapper = SyncClientWrapper(
        base_url=_get_base_url(base_url=base_url, environment=environment),
        client_name=client_name,
        token=token,
        headers=headers,
        # Only pass follow_redirects to the default client when the caller
        # provided a value; None means "use httpx's own default".
        httpx_client=httpx_client
        if httpx_client is not None
        else httpx.Client(timeout=_defaulted_timeout, follow_redirects=follow_redirects)
        if follow_redirects is not None
        else httpx.Client(timeout=_defaulted_timeout),
        timeout=_defaulted_timeout,
        logging=logging,
    )
    self._raw_client = RawBaseCohere(client_wrapper=self._client_wrapper)
    # Sub-client slots start as None — presumably populated lazily by the
    # corresponding accessor properties; confirm against the full class.
    self._v2: typing.Optional[V2Client] = None
    self._batches: typing.Optional[BatchesClient] = None
    self._embed_jobs: typing.Optional[EmbedJobsClient] = None
    self._datasets: typing.Optional[DatasetsClient] = None
    self._connectors: typing.Optional[ConnectorsClient] = None
    self._models: typing.Optional[ModelsClient] = None
    self._finetuning: typing.Optional[FinetuningClient] = None
    self._audio: typing.Optional[AudioClient] = None
@property
def with_raw_response(self) -> RawBaseCohere:
    """
    Expose the raw-response variant of this client, whose methods return
    raw responses instead of parsed data.

    Returns
    -------
    RawBaseCohere
    """
    raw_client = self._raw_client
    return raw_client
def chat_stream(
    self,
    *,
    message: str,
    accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
    model: typing.Optional[str] = OMIT,
    preamble: typing.Optional[str] = OMIT,
    chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
    conversation_id: typing.Optional[str] = OMIT,
    prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT,
    connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
    search_queries_only: typing.Optional[bool] = OMIT,
    documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
    citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    max_input_tokens: typing.Optional[int] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
    tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
    force_single_step: typing.Optional[bool] = OMIT,
    response_format: typing.Optional[ResponseFormat] = OMIT,
    safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[StreamedChatResponse]:
    """
    Generates a streamed text response to a user message.

    To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

    Parameters
    ----------
    message : str
        Text input for the model to respond to.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    accepts : typing.Optional[typing.Literal["text/event-stream"]]
        Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.

    model : typing.Optional[str]
        The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.

        Compatible Deployments: Cohere Platform, Private Deployments

    preamble : typing.Optional[str]
        When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.

        The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    chat_history : typing.Optional[typing.Sequence[Message]]
        A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.

        Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.

        The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    conversation_id : typing.Optional[str]
        An alternative to `chat_history`.

        Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.

        Compatible Deployments: Cohere Platform

    prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation]
        Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.

        Dictates how the prompt will be constructed.

        With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.

        With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.

        With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.

        Compatible Deployments:
         - AUTO: Cohere Platform Only
         - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments

    connectors : typing.Optional[typing.Sequence[ChatConnector]]
        Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.

        When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).

        Compatible Deployments: Cohere Platform

    search_queries_only : typing.Optional[bool]
        Defaults to `false`.

        When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    documents : typing.Optional[typing.Sequence[ChatDocument]]
        A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.

        Example:
        ```
        [
          { "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
          { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
        ]
        ```

        Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.

        Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.

        An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.

        An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.

        See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    citation_quality : typing.Optional[ChatStreamRequestCitationQuality]
        Defaults to `"enabled"`.

        Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    temperature : typing.Optional[float]
        Defaults to `0.3`.

        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

        Randomness can be further maximized by increasing the value of the `p` parameter.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    max_tokens : typing.Optional[int]
        The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    max_input_tokens : typing.Optional[int]
        The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.

        Input will be truncated according to the `prompt_truncation` parameter.

        Compatible Deployments: Cohere Platform

    k : typing.Optional[int]
        Ensures only the top `k` most likely tokens are considered for generation at each step.
        Defaults to `0`, min value of `0`, max value of `500`.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    stop_sequences : typing.Optional[typing.Sequence[str]]
        A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    frequency_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

        Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    raw_prompting : typing.Optional[bool]
        When enabled, the user's prompt will be sent to the model without
        any pre-processing.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    tools : typing.Optional[typing.Sequence[Tool]]
        A list of available tools (functions) that the model may suggest invoking before producing a text response.

        When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    tool_results : typing.Optional[typing.Sequence[ToolResult]]
        A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
        Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.

        **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
        ```
        tool_results = [
          {
            "call": {
              "name": ,
              "parameters": {
                :
              }
            },
            "outputs": [{
              :
            }]
          },
          ...
        ]
        ```
        **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    force_single_step : typing.Optional[bool]
        Forces the chat to be single step. Defaults to `false`.

    response_format : typing.Optional[ResponseFormat]

    safety_mode : typing.Optional[ChatStreamRequestSafetyMode]
        Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
        When `NONE` is specified, the safety instruction will be omitted.

        Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.

        **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

        **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Yields
    ------
    typing.Iterator[StreamedChatResponse]


    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    response = client.chat_stream(
        model="command-a-03-2025",
        message="hello!",
    )
    for chunk in response:
        print(chunk)
    """
    with self._raw_client.chat_stream(
        message=message,
        accepts=accepts,
        model=model,
        preamble=preamble,
        chat_history=chat_history,
        conversation_id=conversation_id,
        prompt_truncation=prompt_truncation,
        connectors=connectors,
        search_queries_only=search_queries_only,
        documents=documents,
        citation_quality=citation_quality,
        temperature=temperature,
        max_tokens=max_tokens,
        max_input_tokens=max_input_tokens,
        k=k,
        p=p,
        seed=seed,
        stop_sequences=stop_sequences,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        raw_prompting=raw_prompting,
        tools=tools,
        tool_results=tool_results,
        force_single_step=force_single_step,
        response_format=response_format,
        safety_mode=safety_mode,
        request_options=request_options,
    ) as r:
        # Yield parsed chunks from inside the raw client's context manager so
        # the underlying HTTP stream is closed when iteration finishes.
        yield from r.data
def chat(
    self,
    *,
    message: str,
    accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
    model: typing.Optional[str] = OMIT,
    preamble: typing.Optional[str] = OMIT,
    chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
    conversation_id: typing.Optional[str] = OMIT,
    prompt_truncation: typing.Optional[ChatRequestPromptTruncation] = OMIT,
    connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
    search_queries_only: typing.Optional[bool] = OMIT,
    documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
    citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    max_input_tokens: typing.Optional[int] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
    tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
    force_single_step: typing.Optional[bool] = OMIT,
    response_format: typing.Optional[ResponseFormat] = OMIT,
    safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> NonStreamedChatResponse:
    """
    Generates a text response to a user message.

    To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

    Parameters
    ----------
    message : str
        Text input for the model to respond to.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    accepts : typing.Optional[typing.Literal["text/event-stream"]]
        Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.

    model : typing.Optional[str]
        The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.

        Compatible Deployments: Cohere Platform, Private Deployments

    preamble : typing.Optional[str]
        When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.

        The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    chat_history : typing.Optional[typing.Sequence[Message]]
        A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.

        Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.

        The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    conversation_id : typing.Optional[str]
        An alternative to `chat_history`.

        Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.

        Compatible Deployments: Cohere Platform

    prompt_truncation : typing.Optional[ChatRequestPromptTruncation]
        Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.

        Dictates how the prompt will be constructed.

        With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.

        With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.

        With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.

        Compatible Deployments:
         - AUTO: Cohere Platform Only
         - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments

    connectors : typing.Optional[typing.Sequence[ChatConnector]]
        Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.

        When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).

        Compatible Deployments: Cohere Platform

    search_queries_only : typing.Optional[bool]
        Defaults to `false`.

        When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    documents : typing.Optional[typing.Sequence[ChatDocument]]
        A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.

        Example:
        ```
        [
          { "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
          { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
        ]
        ```

        Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.

        Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.

        An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.

        An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.

        See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    citation_quality : typing.Optional[ChatRequestCitationQuality]
        Defaults to `"enabled"`.

        Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    temperature : typing.Optional[float]
        Defaults to `0.3`.

        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

        Randomness can be further maximized by increasing the value of the `p` parameter.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    max_tokens : typing.Optional[int]
        The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    max_input_tokens : typing.Optional[int]
        The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.

        Input will be truncated according to the `prompt_truncation` parameter.

        Compatible Deployments: Cohere Platform

    k : typing.Optional[int]
        Ensures only the top `k` most likely tokens are considered for generation at each step.
        Defaults to `0`, min value of `0`, max value of `500`.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    stop_sequences : typing.Optional[typing.Sequence[str]]
        A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    frequency_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.

        Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    raw_prompting : typing.Optional[bool]
        When enabled, the user's prompt will be sent to the model without
        any pre-processing.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    tools : typing.Optional[typing.Sequence[Tool]]
        A list of available tools (functions) that the model may suggest invoking before producing a text response.

        When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    tool_results : typing.Optional[typing.Sequence[ToolResult]]
        A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
        Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.

        **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
        ```
        tool_results = [
          {
            "call": {
              "name": ,
              "parameters": {
                :
              }
            },
            "outputs": [{
              :
            }]
          },
          ...
        ]
        ```
        **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    force_single_step : typing.Optional[bool]
        Forces the chat to be single step. Defaults to `false`.

    response_format : typing.Optional[ResponseFormat]

    safety_mode : typing.Optional[ChatRequestSafetyMode]
        Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
        When `NONE` is specified, the safety instruction will be omitted.

        Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.

        **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

        **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    NonStreamedChatResponse


    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.chat(
        model="command-a-03-2025",
        message="Tell me about LLMs",
    )
    """
    _response = self._raw_client.chat(
        message=message,
        accepts=accepts,
        model=model,
        preamble=preamble,
        chat_history=chat_history,
        conversation_id=conversation_id,
        prompt_truncation=prompt_truncation,
        connectors=connectors,
        search_queries_only=search_queries_only,
        documents=documents,
        citation_quality=citation_quality,
        temperature=temperature,
        max_tokens=max_tokens,
        max_input_tokens=max_input_tokens,
        k=k,
        p=p,
        seed=seed,
        stop_sequences=stop_sequences,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        raw_prompting=raw_prompting,
        tools=tools,
        tool_results=tool_results,
        force_single_step=force_single_step,
        response_format=response_format,
        safety_mode=safety_mode,
        request_options=request_options,
    )
    # Unwrap the parsed body from the raw HTTP response wrapper.
    return _response.data
def generate_stream(
    self,
    *,
    prompt: str,
    model: typing.Optional[str] = OMIT,
    num_generations: typing.Optional[int] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    truncate: typing.Optional[GenerateStreamRequestTruncate] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    preset: typing.Optional[str] = OMIT,
    end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[GenerateStreamedResponse]:
    """
    This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.

    Generates realistic text conditioned on a given input.

    Parameters
    ----------
    prompt : str
        The input text that serves as the starting point for generating the response.
        Note: The prompt will be pre-processed and modified before reaching the model.

    model : typing.Optional[str]
        The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
        Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

    num_generations : typing.Optional[int]
        The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

    max_tokens : typing.Optional[int]
        The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
        This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
        Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

    truncate : typing.Optional[GenerateStreamRequestTruncate]
        One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
        Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
        If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

    temperature : typing.Optional[float]
        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
        Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    preset : typing.Optional[str]
        Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
        When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

    end_sequences : typing.Optional[typing.Sequence[str]]
        The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.

    k : typing.Optional[int]
        Ensures only the top `k` most likely tokens are considered for generation at each step.
        Defaults to `0`, min value of `0`, max value of `500`.

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

    frequency_penalty : typing.Optional[float]
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
        Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
        Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

    return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods]
        One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
        If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
        WARNING: `ALL` is deprecated, and will be removed in a future release.

    raw_prompting : typing.Optional[bool]
        When enabled, the user's prompt will be sent to the model without any pre-processing.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Yields
    ------
    typing.Iterator[GenerateStreamedResponse]

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    response = client.generate_stream(
        prompt="Please explain to me how LLMs work",
    )
    for chunk in response:
        print(chunk)
    """
    # Stream through the raw client's context manager so the underlying HTTP
    # response is released once the generator is exhausted or discarded.
    with self._raw_client.generate_stream(
        prompt=prompt,
        model=model,
        num_generations=num_generations,
        max_tokens=max_tokens,
        truncate=truncate,
        temperature=temperature,
        seed=seed,
        preset=preset,
        end_sequences=end_sequences,
        stop_sequences=stop_sequences,
        k=k,
        p=p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        return_likelihoods=return_likelihoods,
        raw_prompting=raw_prompting,
        request_options=request_options,
    ) as r:
        yield from r.data
def generate(
    self,
    *,
    prompt: str,
    model: typing.Optional[str] = OMIT,
    num_generations: typing.Optional[int] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    truncate: typing.Optional[GenerateRequestTruncate] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    preset: typing.Optional[str] = OMIT,
    end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> Generation:
    """
    Generates realistic text conditioned on a given input.

    This API is marked as "Legacy" and is no longer maintained. Follow the
    [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat)
    to start using the Chat API.

    Parameters
    ----------
    prompt : str
        The input text that serves as the starting point for generating the response.
        Note: The prompt will be pre-processed and modified before reaching the model.
    model : typing.Optional[str]
        The identifier of the model to generate with. Currently available models are
        `command` (default), `command-nightly` (experimental), `command-light`, and
        `command-light-nightly` (experimental). [Custom models](https://docs.cohere.com/docs/training-custom-models)
        can also be supplied with their full ID.
    num_generations : typing.Optional[int]
        The maximum number of generations that will be returned. Defaults to `1`, min `1`, max `5`.
    max_tokens : typing.Optional[int]
        The maximum number of tokens the model will generate. Off by default, in which case
        the model generates until it emits an EOS token. Can only be `0` when
        `return_likelihoods` is `ALL` (to get the likelihood of the prompt).
    truncate : typing.Optional[GenerateRequestTruncate]
        One of `NONE|START|END` controlling how over-length inputs are handled: `START`/`END`
        discard that side of the input down to the model's maximum input length, while `NONE`
        returns an error for over-length input.
    temperature : typing.Optional[float]
        Non-negative float tuning generation randomness. Defaults to `0.75`; range `0.0`–`5.0`.
    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens deterministically
        for repeated requests with the same seed and parameters; determinism is not guaranteed.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
    preset : typing.Optional[str]
        Identifier of a custom preset created in the
        [playground](https://dashboard.cohere.com/playground/generate). When set, `prompt`
        becomes optional and any included parameters override the preset's.
    end_sequences : typing.Optional[typing.Sequence[str]]
        Generated text is cut at the beginning of the earliest end sequence (sequence excluded).
    stop_sequences : typing.Optional[typing.Sequence[str]]
        Generated text is cut at the end of the earliest stop sequence (sequence included).
    k : typing.Optional[int]
        Only the top `k` most likely tokens are considered at each step. Defaults to `0`; max `500`.
    p : typing.Optional[float]
        Only tokens with total probability mass `p` are considered at each step; applied after `k`
        when both are enabled. Defaults to `0.75`; range `0.01`–`0.99`.
    frequency_penalty : typing.Optional[float]
        Penalizes tokens proportionally to how often they already appeared. Not supported in
        combination with `presence_penalty` on newer models.
    presence_penalty : typing.Optional[float]
        Penalizes all previously-seen tokens equally regardless of frequency. Defaults to `0.0`;
        range `0.0`–`1.0`. Not supported in combination with `frequency_penalty` on newer models.
    return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods]
        One of `GENERATION|NONE` (default `NONE`). `GENERATION` returns likelihoods for generated
        text only. WARNING: `ALL` is deprecated and will be removed in a future release.
    raw_prompting : typing.Optional[bool]
        When enabled, the prompt is sent to the model without any pre-processing.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    Generation

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.generate(
        prompt="Please explain to me how LLMs work",
    )
    """
    # Delegate to the raw client and return the parsed payload.
    raw_response = self._raw_client.generate(
        prompt=prompt,
        model=model,
        num_generations=num_generations,
        max_tokens=max_tokens,
        truncate=truncate,
        temperature=temperature,
        seed=seed,
        preset=preset,
        end_sequences=end_sequences,
        stop_sequences=stop_sequences,
        k=k,
        p=p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        return_likelihoods=return_likelihoods,
        raw_prompting=raw_prompting,
        request_options=request_options,
    )
    return raw_response.data
def embed(
    self,
    *,
    texts: typing.Optional[typing.Sequence[str]] = OMIT,
    images: typing.Optional[typing.Sequence[str]] = OMIT,
    model: typing.Optional[str] = OMIT,
    input_type: typing.Optional[EmbedInputType] = OMIT,
    embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
    truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> EmbedResponse:
    """
    Returns text and image embeddings. An embedding is a list of floating point
    numbers that captures semantic information about the content it represents.

    Embeddings can be used to create classifiers as well as empower semantic search.
    To learn more, see the embedding page and the
    [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

    Parameters
    ----------
    texts : typing.Optional[typing.Sequence[str]]
        An array of strings for the model to embed. Maximum of `96` texts per call.
    images : typing.Optional[typing.Sequence[str]]
        An array of image [data URIs](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data)
        to embed; maximum of `1` per call. Must be `image/jpeg`, `image/png`, `image/webp`,
        or `image/gif` and at most 5MB. Only supported with Embed v3.0 and newer models.
    model : typing.Optional[str]
        ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
    input_type : typing.Optional[EmbedInputType]
    embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
        The embedding types to return. Defaults to None, which returns the Embed Floats
        response type. One or more of `"float"` (all models), `"int8"`, `"uint8"`,
        `"binary"`, `"ubinary"` (Embed v3.0 and newer).
    truncate : typing.Optional[EmbedRequestTruncate]
        One of `NONE|START|END` controlling how over-length inputs are handled: `START`/`END`
        discard that side of the input down to the model's maximum input length, while `NONE`
        returns an error for over-length input.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    EmbedResponse
        OK

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.embed(
        texts=["hello", "goodbye"],
        model="embed-v4.0",
        input_type="classification",
    )
    """
    # Delegate to the raw client and return the parsed payload.
    raw_response = self._raw_client.embed(
        texts=texts,
        images=images,
        model=model,
        input_type=input_type,
        embedding_types=embedding_types,
        truncate=truncate,
        request_options=request_options,
    )
    return raw_response.data
def rerank(
    self,
    *,
    query: str,
    documents: typing.Sequence[RerankRequestDocumentsItem],
    model: typing.Optional[str] = OMIT,
    top_n: typing.Optional[int] = OMIT,
    rank_fields: typing.Optional[typing.Sequence[str]] = OMIT,
    return_documents: typing.Optional[bool] = OMIT,
    max_chunks_per_doc: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> RerankResponse:
    """
    Takes a query and a list of texts and produces an ordered array with each
    text assigned a relevance score.

    Parameters
    ----------
    query : str
        The search query
    documents : typing.Sequence[RerankRequestDocumentsItem]
        A list of document objects or strings to rerank. If a document is provided the
        text fields is required and all other fields will be preserved in the response.
        The total max chunks (length of documents * max_chunks_per_doc) must be less
        than 10000. We recommend a maximum of 1,000 documents for optimal performance.
    model : typing.Optional[str]
        The identifier of the model to use, eg `rerank-v3.5`.
    top_n : typing.Optional[int]
        The number of most relevant documents or indices to return; defaults to the
        length of the documents.
    rank_fields : typing.Optional[typing.Sequence[str]]
        For JSON documents, the keys to consider for reranking, in order (e.g.
        `rank_fields=['title','author','text']` reranks on those values sequentially;
        fields beyond the model's context length are not re-considered). Defaults to
        the `text` field.
    return_documents : typing.Optional[bool]
        - If false, results omit the doc text: a list of {index, relevance score}
          where index is inferred from the request list.
        - If true, results include the doc text: an ordered list of
          {index, text, relevance score} referring back to the request list.
    max_chunks_per_doc : typing.Optional[int]
        The maximum number of chunks to produce internally from a document
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    RerankResponse
        OK

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.rerank(
        documents=[
            {"text": "Carson City is the capital city of the American state of Nevada."},
            {"text": "Washington, D.C. (also known as simply Washington or D.C.) is the capital of the United States."},
            {"text": "Capitalization in English grammar is the use of a capital letter at the start of a word."},
        ],
        query="What is the capital of the United States?",
        top_n=3,
        model="rerank-v4.0-pro",
    )
    """
    # Delegate to the raw client and return the parsed payload.
    raw_response = self._raw_client.rerank(
        query=query,
        documents=documents,
        model=model,
        top_n=top_n,
        rank_fields=rank_fields,
        return_documents=return_documents,
        max_chunks_per_doc=max_chunks_per_doc,
        request_options=request_options,
    )
    return raw_response.data
def classify(
    self,
    *,
    inputs: typing.Sequence[str],
    examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT,
    model: typing.Optional[str] = OMIT,
    preset: typing.Optional[str] = OMIT,
    truncate: typing.Optional[ClassifyRequestTruncate] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> ClassifyResponse:
    """
    Predicts which label fits the specified text inputs best, using the provided
    `examples` of text + label pairs as a reference.

    Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning)
    trained on classification examples don't require the `examples` parameter to
    be passed in explicitly.

    Parameters
    ----------
    inputs : typing.Sequence[str]
        A list of up to 96 texts to be classified. Each one must be a non-empty string.
        Classification is performed on the first `x` tokens of each input, where `x`
        depends on the underlying model; see the "max tokens" column
        [here](https://docs.cohere.com/docs/models). By default `truncate` is `END`,
        so tokens over the limit are dropped; set `truncate` to `NONE` to get
        validation errors for longer texts instead.
    examples : typing.Optional[typing.Sequence[ClassifyExample]]
        An array of examples providing context to the model, each a text string with
        its associated label/class, structured as `{text: "...",label: "..."}`. Each
        unique label requires at least 2 examples; at most 2500 examples, each up to
        512 tokens. Not required for
        [fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning)
        trained on classification examples.
    model : typing.Optional[str]
        ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model
    preset : typing.Optional[str]
        The ID of a custom playground preset created in the
        [playground](https://dashboard.cohere.com/playground). With a preset, all
        other parameters become optional and any included parameters override the
        preset's parameters.
    truncate : typing.Optional[ClassifyRequestTruncate]
        One of `NONE|START|END` controlling how over-length inputs are handled:
        `START`/`END` discard that side of the input down to the model's maximum
        input length, while `NONE` returns an error for over-length input.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    ClassifyResponse
        OK

    Examples
    --------
    from cohere import ClassifyExample, Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.classify(
        examples=[
            ClassifyExample(text="Dermatologists don't like her!", label="Spam"),
            ClassifyExample(text="I need help please wire me $1000 right now", label="Spam"),
            ClassifyExample(text="Your parcel will be delivered today", label="Not spam"),
            ClassifyExample(text="Weekly sync notes", label="Not spam"),
        ],
        inputs=["Confirm your email address", "hey i need u to send some $"],
        model="YOUR-FINE-TUNED-MODEL-ID",
    )
    """
    # Delegate to the raw client and return the parsed payload.
    raw_response = self._raw_client.classify(
        inputs=inputs,
        examples=examples,
        model=model,
        preset=preset,
        truncate=truncate,
        request_options=request_options,
    )
    return raw_response.data
def summarize(
    self,
    *,
    text: str,
    length: typing.Optional[SummarizeRequestLength] = OMIT,
    format: typing.Optional[SummarizeRequestFormat] = OMIT,
    model: typing.Optional[str] = OMIT,
    extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    additional_command: typing.Optional[str] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> SummarizeResponse:
    """
    Generates a summary in English for a given text.

    This API is marked as "Legacy" and is no longer maintained. Follow the
    [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat)
    to start using the Chat API.

    Parameters
    ----------
    text : str
        The text to generate a summary for. Can be up to 100,000 characters long.
        Currently the only supported language is English.
    length : typing.Optional[SummarizeRequestLength]
        One of `short`, `medium`, `long`, or `auto` (default). The approximate length
        of the summary; `auto` picks the best option based on the input text.
    format : typing.Optional[SummarizeRequestFormat]
        One of `paragraph`, `bullets`, or `auto` (default). The style the summary is
        delivered in — a free form paragraph or bullet points; `auto` picks the best
        option based on the input text.
    model : typing.Optional[str]
        The identifier of the model to generate the summary with. Currently available
        models are `command` (default), `command-nightly` (experimental),
        `command-light`, and `command-light-nightly` (experimental). Smaller, "light"
        models are faster, while larger models will perform better.
    extractiveness : typing.Optional[SummarizeRequestExtractiveness]
        One of `low`, `medium`, `high`, or `auto` (default). How close to the original
        text the summary is: `high` leans towards reusing sentences verbatim, `low`
        paraphrases more; `auto` picks the best option based on the input text.
    temperature : typing.Optional[float]
        Ranges from 0 to 5. Controls the randomness of the output; lower values are
        more "predictable", higher values more "creative". The sweet spot is typically
        between 0 and 1.
    additional_command : typing.Optional[str]
        A free-form instruction for modifying how the summaries get generated. Should
        complete the sentence "Generate a summary _". Eg. "focusing on the next steps"
        or "written by Yoda"
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    SummarizeResponse
        OK

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.summarize(
        text="Ice cream is a sweetened frozen food typically eaten as a snack or dessert. ...",
    )
    """
    # Delegate to the raw client and return the parsed payload.
    raw_response = self._raw_client.summarize(
        text=text,
        length=length,
        format=format,
        model=model,
        extractiveness=extractiveness,
        temperature=temperature,
        additional_command=additional_command,
        request_options=request_options,
    )
    return raw_response.data
def tokenize(
    self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None
) -> TokenizeResponse:
    """
    Splits input text into smaller units called tokens using byte-pair encoding (BPE).
    To learn more about tokenization and byte pair encoding, see the tokens page.

    Parameters
    ----------
    text : str
        The string to be tokenized; minimum length 1 character, maximum 65536 characters.
    model : str
        The input will be tokenized by the tokenizer that is used by this model.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    TokenizeResponse
        OK

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.tokenize(
        text="tokenize me! :D",
        model="command",
    )
    """
    # Delegate to the raw client and return the parsed payload.
    return self._raw_client.tokenize(text=text, model=model, request_options=request_options).data
def detokenize(
    self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None
) -> DetokenizeResponse:
    """
    Takes tokens produced by byte-pair encoding and returns their text representation.
    To learn more about tokenization and byte pair encoding, see the tokens page.

    Parameters
    ----------
    tokens : typing.Sequence[int]
        The list of tokens to be detokenized.
    model : str
        An optional parameter to provide the model name. This will ensure that the
        detokenization is done by the tokenizer used by that model.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    DetokenizeResponse
        OK

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.detokenize(
        tokens=[10002, 2261, 2012, 8, 2792, 43],
        model="command",
    )
    """
    # Delegate to the raw client and return the parsed payload.
    return self._raw_client.detokenize(tokens=tokens, model=model, request_options=request_options).data
def check_api_key(self, *, request_options: typing.Optional[RequestOptions] = None) -> CheckApiKeyResponse:
"""
Checks that the api key in the Authorization header is valid and active
Parameters
----------
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
CheckApiKeyResponse
OK
Examples
--------
from cohere import Client
client = Client(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)
client.check_api_key()
"""
_response = self._raw_client.check_api_key(request_options=request_options)
return _response.data
@property
def v2(self):
if self._v2 is None:
from .v2.client import V2Client # noqa: E402
self._v2 = V2Client(client_wrapper=self._client_wrapper)
return self._v2
@property
def batches(self):
if self._batches is None:
from .batches.client import BatchesClient # noqa: E402
self._batches = BatchesClient(client_wrapper=self._client_wrapper)
return self._batches
@property
def embed_jobs(self):
if self._embed_jobs is None:
from .embed_jobs.client import EmbedJobsClient # noqa: E402
self._embed_jobs = EmbedJobsClient(client_wrapper=self._client_wrapper)
return self._embed_jobs
@property
def datasets(self):
if self._datasets is None:
from .datasets.client import DatasetsClient # noqa: E402
self._datasets = DatasetsClient(client_wrapper=self._client_wrapper)
return self._datasets
@property
def connectors(self):
if self._connectors is None:
from .connectors.client import ConnectorsClient # noqa: E402
self._connectors = ConnectorsClient(client_wrapper=self._client_wrapper)
return self._connectors
@property
def models(self):
if self._models is None:
from .models.client import ModelsClient # noqa: E402
self._models = ModelsClient(client_wrapper=self._client_wrapper)
return self._models
@property
def finetuning(self):
if self._finetuning is None:
from .finetuning.client import FinetuningClient # noqa: E402
self._finetuning = FinetuningClient(client_wrapper=self._client_wrapper)
return self._finetuning
@property
def audio(self):
if self._audio is None:
from .audio.client import AudioClient # noqa: E402
self._audio = AudioClient(client_wrapper=self._client_wrapper)
return self._audio
def _make_default_async_client(
    timeout: typing.Optional[float],
    follow_redirects: typing.Optional[bool],
) -> httpx.AsyncClient:
    """Build the default async HTTP client, preferring the aiohttp transport when installed."""
    # Only pass follow_redirects through when the caller specified it, so the
    # underlying client's own default applies otherwise.
    kwargs: typing.Dict[str, typing.Any] = {"timeout": timeout}
    if follow_redirects is not None:
        kwargs["follow_redirects"] = follow_redirects
    try:
        import httpx_aiohttp  # type: ignore[import-not-found]
    except ImportError:
        # httpx_aiohttp is optional; fall back to the plain httpx client.
        return httpx.AsyncClient(**kwargs)
    return httpx_aiohttp.HttpxAiohttpClient(**kwargs)
class AsyncBaseCohere:
"""
Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propagate to these functions.
Parameters
----------
base_url : typing.Optional[str]
The base url to use for requests from the client.
environment : ClientEnvironment
        The environment to use for requests from the client. Import it via ``from .environment import ClientEnvironment``.
Defaults to ClientEnvironment.PRODUCTION
client_name : typing.Optional[str]
token : typing.Optional[typing.Union[str, typing.Callable[[], str]]]
headers : typing.Optional[typing.Dict[str, str]]
Additional headers to send with every request.
async_token : typing.Optional[typing.Callable[[], typing.Awaitable[str]]]
An async callable that returns a bearer token. Use this when token acquisition involves async I/O (e.g., refreshing tokens via an async HTTP client). When provided, this is used instead of the synchronous token for async requests.
timeout : typing.Optional[float]
The timeout to be used, in seconds, for requests. By default the timeout is 300 seconds, unless a custom httpx client is used, in which case this default is not enforced.
follow_redirects : typing.Optional[bool]
Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in.
httpx_client : typing.Optional[httpx.AsyncClient]
The httpx client to use for making requests, a preconfigured client is used by default, however this is useful should you want to pass in any custom httpx configuration.
logging : typing.Optional[typing.Union[LogConfig, Logger]]
Configure logging for the SDK. Accepts a LogConfig dict with 'level' (debug/info/warn/error), 'logger' (custom logger implementation), and 'silent' (boolean, defaults to True) fields. You can also pass a pre-configured Logger instance.
Examples
--------
from cohere import AsyncClient
client = AsyncClient(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)
"""
def __init__(
self,
*,
base_url: typing.Optional[str] = None,
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
client_name: typing.Optional[str] = None,
token: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = os.getenv("CO_API_KEY"),
headers: typing.Optional[typing.Dict[str, str]] = None,
async_token: typing.Optional[typing.Callable[[], typing.Awaitable[str]]] = None,
timeout: typing.Optional[float] = None,
follow_redirects: typing.Optional[bool] = True,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
):
_defaulted_timeout = (
timeout if timeout is not None else 300 if httpx_client is None else httpx_client.timeout.read
)
if token is None:
raise ApiError(body="The client must be instantiated be either passing in token or setting CO_API_KEY")
self._client_wrapper = AsyncClientWrapper(
base_url=_get_base_url(base_url=base_url, environment=environment),
client_name=client_name,
token=token,
headers=headers,
async_token=async_token,
httpx_client=httpx_client
if httpx_client is not None
else _make_default_async_client(timeout=_defaulted_timeout, follow_redirects=follow_redirects),
timeout=_defaulted_timeout,
logging=logging,
)
self._raw_client = AsyncRawBaseCohere(client_wrapper=self._client_wrapper)
self._v2: typing.Optional[AsyncV2Client] = None
self._batches: typing.Optional[AsyncBatchesClient] = None
self._embed_jobs: typing.Optional[AsyncEmbedJobsClient] = None
self._datasets: typing.Optional[AsyncDatasetsClient] = None
self._connectors: typing.Optional[AsyncConnectorsClient] = None
self._models: typing.Optional[AsyncModelsClient] = None
self._finetuning: typing.Optional[AsyncFinetuningClient] = None
self._audio: typing.Optional[AsyncAudioClient] = None
@property
def with_raw_response(self) -> AsyncRawBaseCohere:
"""
Retrieves a raw implementation of this client that returns raw responses.
Returns
-------
AsyncRawBaseCohere
"""
return self._raw_client
async def chat_stream(
self,
*,
message: str,
accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
model: typing.Optional[str] = OMIT,
preamble: typing.Optional[str] = OMIT,
chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
conversation_id: typing.Optional[str] = OMIT,
prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT,
connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
search_queries_only: typing.Optional[bool] = OMIT,
documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT,
temperature: typing.Optional[float] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
max_input_tokens: typing.Optional[int] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
raw_prompting: typing.Optional[bool] = OMIT,
tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
force_single_step: typing.Optional[bool] = OMIT,
response_format: typing.Optional[ResponseFormat] = OMIT,
safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[StreamedChatResponse]:
"""
Generates a streamed text response to a user message.
To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
Parameters
----------
message : str
Text input for the model to respond to.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
accepts : typing.Optional[typing.Literal["text/event-stream"]]
Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.
model : typing.Optional[str]
The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
Compatible Deployments: Cohere Platform, Private Deployments
preamble : typing.Optional[str]
When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
chat_history : typing.Optional[typing.Sequence[Message]]
A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
conversation_id : typing.Optional[str]
An alternative to `chat_history`.
Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
Compatible Deployments: Cohere Platform
prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation]
Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
Dictates how the prompt will be constructed.
With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
Compatible Deployments:
- AUTO: Cohere Platform Only
- AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
connectors : typing.Optional[typing.Sequence[ChatConnector]]
Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
Compatible Deployments: Cohere Platform
search_queries_only : typing.Optional[bool]
Defaults to `false`.
When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
documents : typing.Optional[typing.Sequence[ChatDocument]]
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
Example:
```
[
{ "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
{ "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
]
```
Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
citation_quality : typing.Optional[ChatStreamRequestCitationQuality]
Defaults to `"enabled"`.
Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
temperature : typing.Optional[float]
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_tokens : typing.Optional[int]
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_input_tokens : typing.Optional[int]
The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
Input will be truncated according to the `prompt_truncation` parameter.
Compatible Deployments: Cohere Platform
k : typing.Optional[int]
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
stop_sequences : typing.Optional[typing.Sequence[str]]
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
frequency_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
raw_prompting : typing.Optional[bool]
When enabled, the user's prompt will be sent to the model without
any pre-processing.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tools : typing.Optional[typing.Sequence[Tool]]
A list of available tools (functions) that the model may suggest invoking before producing a text response.
When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tool_results : typing.Optional[typing.Sequence[ToolResult]]
A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
**Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
```
tool_results = [
{
"call": {
"name": ,
"parameters": {
:
}
},
"outputs": [{
:
}]
},
...
]
```
**Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
force_single_step : typing.Optional[bool]
Forces the chat to be single step. Defaults to `false`.
response_format : typing.Optional[ResponseFormat]
safety_mode : typing.Optional[ChatStreamRequestSafetyMode]
Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `NONE` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Yields
------
typing.AsyncIterator[StreamedChatResponse]
Examples
--------
import asyncio
from cohere import AsyncClient
client = AsyncClient(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)
async def main() -> None:
response = await client.chat_stream(
model="command-a-03-2025",
message="hello!",
)
async for chunk in response:
yield chunk
asyncio.run(main())
"""
async with self._raw_client.chat_stream(
message=message,
accepts=accepts,
model=model,
preamble=preamble,
chat_history=chat_history,
conversation_id=conversation_id,
prompt_truncation=prompt_truncation,
connectors=connectors,
search_queries_only=search_queries_only,
documents=documents,
citation_quality=citation_quality,
temperature=temperature,
max_tokens=max_tokens,
max_input_tokens=max_input_tokens,
k=k,
p=p,
seed=seed,
stop_sequences=stop_sequences,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
raw_prompting=raw_prompting,
tools=tools,
tool_results=tool_results,
force_single_step=force_single_step,
response_format=response_format,
safety_mode=safety_mode,
request_options=request_options,
) as r:
async for _chunk in r.data:
yield _chunk
async def chat(
self,
*,
message: str,
accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
model: typing.Optional[str] = OMIT,
preamble: typing.Optional[str] = OMIT,
chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
conversation_id: typing.Optional[str] = OMIT,
prompt_truncation: typing.Optional[ChatRequestPromptTruncation] = OMIT,
connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
search_queries_only: typing.Optional[bool] = OMIT,
documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT,
temperature: typing.Optional[float] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
max_input_tokens: typing.Optional[int] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
raw_prompting: typing.Optional[bool] = OMIT,
tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
force_single_step: typing.Optional[bool] = OMIT,
response_format: typing.Optional[ResponseFormat] = OMIT,
safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> NonStreamedChatResponse:
"""
Generates a text response to a user message.
To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
Parameters
----------
message : str
Text input for the model to respond to.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
accepts : typing.Optional[typing.Literal["text/event-stream"]]
Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.
model : typing.Optional[str]
The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
Compatible Deployments: Cohere Platform, Private Deployments
preamble : typing.Optional[str]
When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
chat_history : typing.Optional[typing.Sequence[Message]]
A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
conversation_id : typing.Optional[str]
An alternative to `chat_history`.
Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
Compatible Deployments: Cohere Platform
prompt_truncation : typing.Optional[ChatRequestPromptTruncation]
Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
Dictates how the prompt will be constructed.
With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
Compatible Deployments:
- AUTO: Cohere Platform Only
- AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
connectors : typing.Optional[typing.Sequence[ChatConnector]]
Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
Compatible Deployments: Cohere Platform
search_queries_only : typing.Optional[bool]
Defaults to `false`.
When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
documents : typing.Optional[typing.Sequence[ChatDocument]]
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
Example:
```
[
{ "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
{ "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
]
```
Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
citation_quality : typing.Optional[ChatRequestCitationQuality]
Defaults to `"enabled"`.
Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
temperature : typing.Optional[float]
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_tokens : typing.Optional[int]
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_input_tokens : typing.Optional[int]
The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
Input will be truncated according to the `prompt_truncation` parameter.
Compatible Deployments: Cohere Platform
k : typing.Optional[int]
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
stop_sequences : typing.Optional[typing.Sequence[str]]
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
frequency_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
raw_prompting : typing.Optional[bool]
When enabled, the user's prompt will be sent to the model without
any pre-processing.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tools : typing.Optional[typing.Sequence[Tool]]
A list of available tools (functions) that the model may suggest invoking before producing a text response.
When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tool_results : typing.Optional[typing.Sequence[ToolResult]]
A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
**Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
```
tool_results = [
{
"call": {
"name": ,
"parameters": {
:
}
},
"outputs": [{
:
}]
},
...
]
```
**Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
force_single_step : typing.Optional[bool]
Forces the chat to be single step. Defaults to `false`.
response_format : typing.Optional[ResponseFormat]
safety_mode : typing.Optional[ChatRequestSafetyMode]
Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `NONE` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
NonStreamedChatResponse
Examples
--------
import asyncio
from cohere import AsyncClient
client = AsyncClient(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)
async def main() -> None:
await client.chat(
model="command-a-03-2025",
message="Tell me about LLMs",
)
asyncio.run(main())
"""
_response = await self._raw_client.chat(
message=message,
accepts=accepts,
model=model,
preamble=preamble,
chat_history=chat_history,
conversation_id=conversation_id,
prompt_truncation=prompt_truncation,
connectors=connectors,
search_queries_only=search_queries_only,
documents=documents,
citation_quality=citation_quality,
temperature=temperature,
max_tokens=max_tokens,
max_input_tokens=max_input_tokens,
k=k,
p=p,
seed=seed,
stop_sequences=stop_sequences,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
raw_prompting=raw_prompting,
tools=tools,
tool_results=tool_results,
force_single_step=force_single_step,
response_format=response_format,
safety_mode=safety_mode,
request_options=request_options,
)
return _response.data
async def generate_stream(
    self,
    *,
    prompt: str,
    model: typing.Optional[str] = OMIT,
    num_generations: typing.Optional[int] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    truncate: typing.Optional[GenerateStreamRequestTruncate] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    preset: typing.Optional[str] = OMIT,
    end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[GenerateStreamedResponse]:
    """
    This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.

    Generates realistic text conditioned on a given input.

    Parameters
    ----------
    prompt : str
        The input text that serves as the starting point for generating the response.
        Note: The prompt will be pre-processed and modified before reaching the model.

    model : typing.Optional[str]
        The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
        Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

    num_generations : typing.Optional[int]
        The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

    max_tokens : typing.Optional[int]
        The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
        This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
        Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

    truncate : typing.Optional[GenerateStreamRequestTruncate]
        One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
        Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
        If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

    temperature : typing.Optional[float]
        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
        Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    preset : typing.Optional[str]
        Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
        When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

    end_sequences : typing.Optional[typing.Sequence[str]]
        The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.

    k : typing.Optional[int]
        Ensures only the top `k` most likely tokens are considered for generation at each step.
        Defaults to `0`, min value of `0`, max value of `500`.

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

    frequency_penalty : typing.Optional[float]
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
        Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
        Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

    return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods]
        One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
        If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
        WARNING: `ALL` is deprecated, and will be removed in a future release.

    raw_prompting : typing.Optional[bool]
        When enabled, the user's prompt will be sent to the model without any pre-processing.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Yields
    ------
    typing.AsyncIterator[GenerateStreamedResponse]

    Examples
    --------
    import asyncio

    from cohere import AsyncClient

    client = AsyncClient(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    # NOTE: this method is an async generator function, so calling it returns
    # an async iterator directly — do NOT await the call itself.
    async def main() -> None:
        response = client.generate_stream(
            prompt="Please explain to me how LLMs work",
        )
        async for chunk in response:
            print(chunk)

    asyncio.run(main())
    """
    # Delegate to the raw client; the context manager owns the underlying
    # streaming HTTP response and closes it when iteration finishes.
    async with self._raw_client.generate_stream(
        prompt=prompt,
        model=model,
        num_generations=num_generations,
        max_tokens=max_tokens,
        truncate=truncate,
        temperature=temperature,
        seed=seed,
        preset=preset,
        end_sequences=end_sequences,
        stop_sequences=stop_sequences,
        k=k,
        p=p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        return_likelihoods=return_likelihoods,
        raw_prompting=raw_prompting,
        request_options=request_options,
    ) as r:
        # Re-yield each parsed event so callers can iterate this method directly.
        async for _chunk in r.data:
            yield _chunk
async def generate(
    self,
    *,
    prompt: str,
    model: typing.Optional[str] = OMIT,
    num_generations: typing.Optional[int] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    truncate: typing.Optional[GenerateRequestTruncate] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    preset: typing.Optional[str] = OMIT,
    end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> Generation:
    """
    Generate realistic text conditioned on a given input.

    This API is marked as "Legacy" and is no longer maintained. Follow the
    [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat)
    to start using the Chat API.

    Parameters
    ----------
    prompt : str
        The input text that serves as the starting point for the response.
        Note: the prompt will be pre-processed and modified before reaching the model.

    model : typing.Optional[str]
        Identifier of the model to generate with: `command` (default), `command-nightly`
        (experimental), `command-light`, or `command-light-nightly` (experimental).
        Smaller, "light" models are faster; larger models perform better.
        [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be
        supplied with their full ID.

    num_generations : typing.Optional[int]
        Maximum number of generations to return. Defaults to `1`; min `1`, max `5`.

    max_tokens : typing.Optional[int]
        Maximum number of tokens to generate. Off by default; when unset, the model
        generates until it emits an EOS completion token. Setting a low value may
        result in incomplete generations. Can only be `0` if `return_likelihoods`
        is `ALL` (to get the likelihood of the prompt). See [BPE Tokens](/bpe-tokens-wiki).

    truncate : typing.Optional[GenerateRequestTruncate]
        One of `NONE|START|END` controlling how over-length inputs are handled.
        `START`/`END` discard from that side until the input fits the model's maximum
        input token length; `NONE` returns an error for over-length input.

    temperature : typing.Optional[float]
        Non-negative float tuning randomness; lower is less random. Defaults to
        `0.75`; min `0.0`, max `5.0`. See [Temperature](/temperature-wiki).

    seed : typing.Optional[int]
        If specified, the backend makes a best effort to sample deterministically,
        so repeated requests with the same seed and parameters should match.
        Determinism cannot be totally guaranteed.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    preset : typing.Optional[str]
        Identifier of a custom preset (a saved combination of parameters), created in
        the [playground](https://dashboard.cohere.com/playground/generate). With a
        preset, `prompt` becomes optional and any included parameters override it.

    end_sequences : typing.Optional[typing.Sequence[str]]
        Output is cut at the beginning of the earliest end sequence; the sequence is
        excluded from the text.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        Output is cut at the end of the earliest stop sequence; the sequence is
        included the text.

    k : typing.Optional[int]
        Only the top `k` most likely tokens are considered at each step.
        Defaults to `0`; min `0`, max `500`.

    p : typing.Optional[float]
        Only tokens with total probability mass `p` are considered at each step.
        If both `k` and `p` are enabled, `p` acts after `k`. Defaults to `0.75`;
        min `0.01`, max `0.99`.

    frequency_penalty : typing.Optional[float]
        Reduces repetitiveness proportionally to how often a token has already
        appeared in the prompt or prior generation. Not supported together with
        `presence_penalty` on newer models.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`; min `0.0`, max `1.0`. Like `frequency_penalty`, but the
        penalty applies equally to all tokens that have already appeared, regardless
        of frequency. Not supported together with `frequency_penalty` on newer models.

    return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods]
        One of `GENERATION|NONE`; defaults to `NONE`. `GENERATION` returns likelihoods
        for generated text only. WARNING: `ALL` is deprecated and will be removed in a
        future release.

    raw_prompting : typing.Optional[bool]
        When enabled, the prompt is sent to the model without any pre-processing.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    Generation

    Examples
    --------
    import asyncio

    from cohere import AsyncClient

    client = AsyncClient(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    async def main() -> None:
        await client.generate(
            prompt="Please explain to me how LLMs work",
        )

    asyncio.run(main())
    """
    # Forward everything to the raw client and unwrap the parsed payload.
    response = await self._raw_client.generate(
        prompt=prompt,
        model=model,
        num_generations=num_generations,
        max_tokens=max_tokens,
        truncate=truncate,
        temperature=temperature,
        seed=seed,
        preset=preset,
        end_sequences=end_sequences,
        stop_sequences=stop_sequences,
        k=k,
        p=p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        return_likelihoods=return_likelihoods,
        raw_prompting=raw_prompting,
        request_options=request_options,
    )
    return response.data
async def embed(
    self,
    *,
    texts: typing.Optional[typing.Sequence[str]] = OMIT,
    images: typing.Optional[typing.Sequence[str]] = OMIT,
    model: typing.Optional[str] = OMIT,
    input_type: typing.Optional[EmbedInputType] = OMIT,
    embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
    truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> EmbedResponse:
    """
    Return text and image embeddings.

    An embedding is a list of floating point numbers that captures semantic
    information about the content it represents. Embeddings can be used to create
    classifiers as well as empower semantic search. To learn more, see the
    embedding page and the
    [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

    Parameters
    ----------
    texts : typing.Optional[typing.Sequence[str]]
        An array of strings for the model to embed. Maximum of `96` texts per call.

    images : typing.Optional[typing.Sequence[str]]
        An array of image data URIs to embed. Maximum of `1` image per call. Each
        image must be a valid
        [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data)
        in `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format, with a
        maximum size of 5MB. Images are only supported with Embed v3.0 and newer models.

    model : typing.Optional[str]
        ID of one of the available
        [Embedding models](https://docs.cohere.com/docs/cohere-embed).

    input_type : typing.Optional[EmbedInputType]

    embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
        The embedding types to return. Optional; defaults to None, which returns the
        Embed Floats response type. One or more of:

        * `"float"`: default float embeddings; supported with all Embed models.
        * `"int8"`: signed int8 embeddings; Embed v3.0 and newer.
        * `"uint8"`: unsigned int8 embeddings; Embed v3.0 and newer.
        * `"binary"`: signed binary embeddings; Embed v3.0 and newer.
        * `"ubinary"`: unsigned binary embeddings; Embed v3.0 and newer.

    truncate : typing.Optional[EmbedRequestTruncate]
        One of `NONE|START|END` controlling how over-length inputs are handled.
        `START`/`END` discard from that side until the input fits the model's maximum
        input token length; `NONE` returns an error for over-length input.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    EmbedResponse
        OK

    Examples
    --------
    import asyncio

    from cohere import AsyncClient

    client = AsyncClient(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    async def main() -> None:
        await client.embed(
            texts=["hello", "goodbye"],
            model="embed-v4.0",
            input_type="classification",
        )

    asyncio.run(main())
    """
    # Delegate to the raw client and unwrap the parsed payload in one expression.
    return (
        await self._raw_client.embed(
            texts=texts,
            images=images,
            model=model,
            input_type=input_type,
            embedding_types=embedding_types,
            truncate=truncate,
            request_options=request_options,
        )
    ).data
async def rerank(
    self,
    *,
    query: str,
    documents: typing.Sequence[RerankRequestDocumentsItem],
    model: typing.Optional[str] = OMIT,
    top_n: typing.Optional[int] = OMIT,
    rank_fields: typing.Optional[typing.Sequence[str]] = OMIT,
    return_documents: typing.Optional[bool] = OMIT,
    max_chunks_per_doc: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> RerankResponse:
    """
    Take in a query and a list of texts and produce an ordered array with each
    text assigned a relevance score.

    Parameters
    ----------
    query : str
        The search query

    documents : typing.Sequence[RerankRequestDocumentsItem]
        A list of document objects or strings to rerank.
        If a document is provided the text fields is required and all other fields
        will be preserved in the response.

        The total max chunks (length of documents * max_chunks_per_doc) must be less
        than 10000. We recommend a maximum of 1,000 documents for optimal endpoint
        performance.

    model : typing.Optional[str]
        The identifier of the model to use, eg `rerank-v3.5`.

    top_n : typing.Optional[int]
        The number of most relevant documents or indices to return, defaults to the
        length of the documents

    rank_fields : typing.Optional[typing.Sequence[str]]
        If a JSON object is provided, you can specify which keys you would like to
        have considered for reranking. The model will rerank based on order of the
        fields passed in (i.e. rank_fields=['title','author','text'] will rerank
        using the values in title, author, text sequentially. If the length of title,
        author, and text exceeds the context length of the model, the chunking will
        not re-consider earlier fields). If not provided, the model will use the
        default text field for ranking.

    return_documents : typing.Optional[bool]
        - If false, returns results without the doc text - the api will return a list
          of {index, relevance score} where index is inferred from the list passed
          into the request.
        - If true, returns results with the doc text passed in - the api will return
          an ordered list of {index, text, relevance score} where index + text refers
          to the list passed into the request.

    max_chunks_per_doc : typing.Optional[int]
        The maximum number of chunks to produce internally from a document

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    RerankResponse
        OK

    Examples
    --------
    import asyncio

    from cohere import AsyncClient

    client = AsyncClient(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    async def main() -> None:
        await client.rerank(
            documents=[
                {
                    "text": "Carson City is the capital city of the American state of Nevada."
                },
                {
                    "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
                },
                {
                    "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
                },
                {
                    "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
                },
                {
                    "text": "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
                },
            ],
            query="What is the capital of the United States?",
            top_n=3,
            model="rerank-v4.0-pro",
        )

    asyncio.run(main())
    """
    # Forward all arguments to the raw client and unwrap the parsed payload.
    response = await self._raw_client.rerank(
        query=query,
        documents=documents,
        model=model,
        top_n=top_n,
        rank_fields=rank_fields,
        return_documents=return_documents,
        max_chunks_per_doc=max_chunks_per_doc,
        request_options=request_options,
    )
    return response.data
async def classify(
    self,
    *,
    inputs: typing.Sequence[str],
    examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT,
    model: typing.Optional[str] = OMIT,
    preset: typing.Optional[str] = OMIT,
    truncate: typing.Optional[ClassifyRequestTruncate] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> ClassifyResponse:
    """
    Predict which label fits the specified text inputs best.

    To make a prediction, Classify uses the provided `examples` of text + label
    pairs as a reference. Note:
    [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained
    on classification examples don't require the `examples` parameter to be passed
    in explicitly.

    Parameters
    ----------
    inputs : typing.Sequence[str]
        A list of up to 96 texts to be classified. Each one must be a non-empty
        string. There is, however, no consistent, universal limit to the length a
        particular input can be. We perform classification on the first `x` tokens
        of each input, and `x` varies depending on which underlying model is
        powering classification. The maximum token length for each model is listed
        in the "max tokens" column [here](https://docs.cohere.com/docs/models).

        Note: by default the `truncate` parameter is set to `END`, so tokens
        exceeding the limit will be automatically dropped. This behavior can be
        disabled by setting `truncate` to `NONE`, which will result in validation
        errors for longer texts.

    examples : typing.Optional[typing.Sequence[ClassifyExample]]
        An array of examples to provide context to the model. Each example is a text
        string and its associated label/class. Each unique label requires at least 2
        examples associated with it; the maximum number of examples is 2500, and each
        example has a maximum length of 512 tokens. The values should be structured
        as `{text: "...",label: "..."}`.

        Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning)
        trained on classification examples don't require the `examples` parameter to
        be passed in explicitly.

    model : typing.Optional[str]
        ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training)
        Classify model

    preset : typing.Optional[str]
        The ID of a custom playground preset. You can create presets in the
        [playground](https://dashboard.cohere.com/playground). If you use a preset,
        all other parameters become optional, and any included parameters will
        override the preset's parameters.

    truncate : typing.Optional[ClassifyRequestTruncate]
        One of `NONE|START|END` controlling how over-length inputs are handled.
        `START`/`END` discard from that side until the input fits the model's maximum
        input token length; `NONE` returns an error for over-length input.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    ClassifyResponse
        OK

    Examples
    --------
    import asyncio

    from cohere import AsyncClient, ClassifyExample

    client = AsyncClient(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    async def main() -> None:
        await client.classify(
            examples=[
                ClassifyExample(
                    text="Dermatologists don't like her!",
                    label="Spam",
                ),
                ClassifyExample(
                    text="'Hello, open to this?'",
                    label="Spam",
                ),
                ClassifyExample(
                    text="I need help please wire me $1000 right now",
                    label="Spam",
                ),
                ClassifyExample(
                    text="Nice to know you ;)",
                    label="Spam",
                ),
                ClassifyExample(
                    text="Please help me?",
                    label="Spam",
                ),
                ClassifyExample(
                    text="Your parcel will be delivered today",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="Review changes to our Terms and Conditions",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="Weekly sync notes",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="'Re: Follow up from today's meeting'",
                    label="Not spam",
                ),
                ClassifyExample(
                    text="Pre-read for tomorrow",
                    label="Not spam",
                ),
            ],
            inputs=["Confirm your email address", "hey i need u to send some $"],
            model="YOUR-FINE-TUNED-MODEL-ID",
        )

    asyncio.run(main())
    """
    # Delegate to the raw client and unwrap the parsed payload in one expression.
    return (
        await self._raw_client.classify(
            inputs=inputs,
            examples=examples,
            model=model,
            preset=preset,
            truncate=truncate,
            request_options=request_options,
        )
    ).data
async def summarize(
    self,
    *,
    text: str,
    length: typing.Optional[SummarizeRequestLength] = OMIT,
    format: typing.Optional[SummarizeRequestFormat] = OMIT,
    model: typing.Optional[str] = OMIT,
    extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    additional_command: typing.Optional[str] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> SummarizeResponse:
    """
    Generate a summary in English for a given text.

    This API is marked as "Legacy" and is no longer maintained. Follow the
    [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat)
    to start using the Chat API.

    Parameters
    ----------
    text : str
        The text to generate a summary for. Can be up to 100,000 characters long.
        Currently the only supported language is English.

    length : typing.Optional[SummarizeRequestLength]
        One of `short`, `medium`, `long`, or `auto`; defaults to `auto`. Indicates
        the approximate length of the summary. If `auto` is selected, the best
        option will be picked based on the input text.

    format : typing.Optional[SummarizeRequestFormat]
        One of `paragraph`, `bullets`, or `auto`; defaults to `auto`. Indicates the
        style in which the summary will be delivered - in a free form paragraph or
        in bullet points. If `auto` is selected, the best option will be picked
        based on the input text.

    model : typing.Optional[str]
        The identifier of the model to generate the summary with. Currently
        available models are `command` (default), `command-nightly` (experimental),
        `command-light`, and `command-light-nightly` (experimental). Smaller,
        "light" models are faster, while larger models will perform better.

    extractiveness : typing.Optional[SummarizeRequestExtractiveness]
        One of `low`, `medium`, `high`, or `auto`; defaults to `auto`. Controls how
        close to the original text the summary is. `high` extractiveness summaries
        will lean towards reusing sentences verbatim, while `low` extractiveness
        summaries will tend to paraphrase more. If `auto` is selected, the best
        option will be picked based on the input text.

    temperature : typing.Optional[float]
        Ranges from 0 to 5. Controls the randomness of the output. Lower values
        tend to generate more “predictable” output, while higher values tend to
        generate more “creative” output. The sweet spot is typically between 0 and 1.

    additional_command : typing.Optional[str]
        A free-form instruction for modifying how the summaries get generated.
        Should complete the sentence "Generate a summary _".
        Eg. "focusing on the next steps" or "written by Yoda"

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    SummarizeResponse
        OK

    Examples
    --------
    import asyncio

    from cohere import AsyncClient

    client = AsyncClient(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    async def main() -> None:
        await client.summarize(
            text='Ice cream is a sweetened frozen food typically eaten as a snack or dessert. It may be made from milk or cream and is flavoured with a sweetener, either sugar or an alternative, and a spice, such as cocoa or vanilla, or with fruit such as strawberries or peaches. It can also be made by whisking a flavored cream base and liquid nitrogen together. Food coloring is sometimes added, in addition to stabilizers. The mixture is cooled below the freezing point of water and stirred to incorporate air spaces and to prevent detectable ice crystals from forming. The result is a smooth, semi-solid foam that is solid at very low temperatures (below 2 °C or 35 °F). It becomes more malleable as its temperature increases.\n\nThe meaning of the name "ice cream" varies from one country to another. In some countries, such as the United States, "ice cream" applies only to a specific variety, and most governments regulate the commercial use of the various terms according to the relative quantities of the main ingredients, notably the amount of cream. Products that do not meet the criteria to be called ice cream are sometimes labelled "frozen dairy dessert" instead. In other countries, such as Italy and Argentina, one word is used fo\r all variants. Analogues made from dairy alternatives, such as goat\'s or sheep\'s milk, or milk substitutes (e.g., soy, cashew, coconut, almond milk or tofu), are available for those who are lactose intolerant, allergic to dairy protein or vegan.',
        )

    asyncio.run(main())
    """
    # Forward all arguments to the raw client and unwrap the parsed payload.
    response = await self._raw_client.summarize(
        text=text,
        length=length,
        format=format,
        model=model,
        extractiveness=extractiveness,
        temperature=temperature,
        additional_command=additional_command,
        request_options=request_options,
    )
    return response.data
async def tokenize(
    self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None
) -> TokenizeResponse:
    """
    Split input text into smaller units called tokens using byte-pair encoding (BPE).

    To learn more about tokenization and byte pair encoding, see the tokens page.

    Parameters
    ----------
    text : str
        The string to be tokenized, the minimum text length is 1 character, and the
        maximum text length is 65536 characters.

    model : str
        The input will be tokenized by the tokenizer that is used by this model.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    TokenizeResponse
        OK

    Examples
    --------
    import asyncio

    from cohere import AsyncClient

    client = AsyncClient(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )

    async def main() -> None:
        await client.tokenize(
            text="tokenize me! :D",
            model="command",
        )

    asyncio.run(main())
    """
    # Delegate to the raw client and unwrap the parsed payload in one expression.
    return (await self._raw_client.tokenize(text=text, model=model, request_options=request_options)).data
async def detokenize(
    self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None
) -> DetokenizeResponse:
    """
    Convert BPE tokens back into their text representation. To learn more about
    tokenization and byte pair encoding, see the tokens page.

    Parameters
    ----------
    tokens : typing.Sequence[int]
        The list of tokens to be detokenized.

    model : str
        An optional parameter to provide the model name. This will ensure that the
        detokenization is done by the tokenizer used by that model.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    DetokenizeResponse
        OK
    """
    raw_response = await self._raw_client.detokenize(tokens=tokens, model=model, request_options=request_options)
    return raw_response.data
async def check_api_key(self, *, request_options: typing.Optional[RequestOptions] = None) -> CheckApiKeyResponse:
    """
    Checks that the API key in the Authorization header is valid and active.

    Parameters
    ----------
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    CheckApiKeyResponse
        OK
    """
    raw_response = await self._raw_client.check_api_key(request_options=request_options)
    return raw_response.data
@property
def v2(self):
    """Async v2 API sub-client; imported lazily and cached on first access."""
    if self._v2 is not None:
        return self._v2
    from .v2.client import AsyncV2Client  # noqa: E402

    self._v2 = AsyncV2Client(client_wrapper=self._client_wrapper)
    return self._v2
@property
def batches(self):
    """Async batches sub-client; imported lazily and cached on first access."""
    if self._batches is not None:
        return self._batches
    from .batches.client import AsyncBatchesClient  # noqa: E402

    self._batches = AsyncBatchesClient(client_wrapper=self._client_wrapper)
    return self._batches
@property
def embed_jobs(self):
    """Async embed-jobs sub-client; imported lazily and cached on first access."""
    if self._embed_jobs is not None:
        return self._embed_jobs
    from .embed_jobs.client import AsyncEmbedJobsClient  # noqa: E402

    self._embed_jobs = AsyncEmbedJobsClient(client_wrapper=self._client_wrapper)
    return self._embed_jobs
@property
def datasets(self):
    """Async datasets sub-client; imported lazily and cached on first access."""
    if self._datasets is not None:
        return self._datasets
    from .datasets.client import AsyncDatasetsClient  # noqa: E402

    self._datasets = AsyncDatasetsClient(client_wrapper=self._client_wrapper)
    return self._datasets
@property
def connectors(self):
    """Async connectors sub-client; imported lazily and cached on first access."""
    if self._connectors is not None:
        return self._connectors
    from .connectors.client import AsyncConnectorsClient  # noqa: E402

    self._connectors = AsyncConnectorsClient(client_wrapper=self._client_wrapper)
    return self._connectors
@property
def models(self):
    """Async models sub-client; imported lazily and cached on first access."""
    if self._models is not None:
        return self._models
    from .models.client import AsyncModelsClient  # noqa: E402

    self._models = AsyncModelsClient(client_wrapper=self._client_wrapper)
    return self._models
@property
def finetuning(self):
    """Async finetuning sub-client; imported lazily and cached on first access."""
    if self._finetuning is not None:
        return self._finetuning
    from .finetuning.client import AsyncFinetuningClient  # noqa: E402

    self._finetuning = AsyncFinetuningClient(client_wrapper=self._client_wrapper)
    return self._finetuning
@property
def audio(self):
    """Async audio sub-client; imported lazily and cached on first access."""
    if self._audio is not None:
        return self._audio
    from .audio.client import AsyncAudioClient  # noqa: E402

    self._audio = AsyncAudioClient(client_wrapper=self._client_wrapper)
    return self._audio
def _get_base_url(*, base_url: typing.Optional[str] = None, environment: ClientEnvironment) -> str:
    """
    Resolve the base URL used to construct a client.

    An explicitly supplied ``base_url`` always wins; otherwise the URL is
    taken from the environment's ``value``. Raises when neither is given.
    """
    if base_url is not None:
        return base_url
    if environment is not None:
        return environment.value
    raise Exception("Please pass in either base_url or environment to construct the client")
================================================
FILE: src/cohere/batches/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .types import (
Batch,
BatchStatus,
CancelBatchResponse,
CreateBatchResponse,
GetBatchResponse,
ListBatchesResponse,
)
# Maps each lazily-exported public name to the relative module that defines it;
# consumed by the module-level ``__getattr__`` below.
_dynamic_imports: typing.Dict[str, str] = {
    export_name: ".types"
    for export_name in (
        "Batch",
        "BatchStatus",
        "CancelBatchResponse",
        "CreateBatchResponse",
        "GetBatchResponse",
        "ListBatchesResponse",
    )
}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve lazily-exported attributes on first access (PEP 562 module __getattr__)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # An entry mapped to ".<attr_name>" presumably denotes a submodule re-export,
        # in which case the module object itself is the attribute.
        return module if module_name == f".{attr_name}" else getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Advertise the lazily-importable names (e.g. for tab completion)."""
    return sorted(_dynamic_imports)
# Explicit public API of this package; kept in sync with the keys of the
# lazy-import table `_dynamic_imports` defined above.
__all__ = [
    "Batch",
    "BatchStatus",
    "CancelBatchResponse",
    "CreateBatchResponse",
    "GetBatchResponse",
    "ListBatchesResponse",
]
================================================
FILE: src/cohere/batches/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from .raw_client import AsyncRawBatchesClient, RawBatchesClient
from .types.batch import Batch
from .types.cancel_batch_response import CancelBatchResponse
from .types.create_batch_response import CreateBatchResponse
from .types.get_batch_response import GetBatchResponse
from .types.list_batches_response import ListBatchesResponse
# Sentinel (Ellipsis cast to Any) used as the default value for optional
# parameters, so that "argument not provided" can be distinguished from an
# explicit ``None``.
OMIT = typing.cast(typing.Any, ...)
class BatchesClient:
    """
    High-level synchronous client for the v2 batch endpoints.

    A thin wrapper over :class:`RawBatchesClient` that unwraps each
    ``HttpResponse`` and returns the parsed payload directly.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawBatchesClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawBatchesClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawBatchesClient
        """
        return self._raw_client

    def list(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListBatchesResponse:
        """
        List the batches for the current user.

        Parameters
        ----------
        page_size : typing.Optional[int]
            The maximum number of batches to return; the service may return fewer.
            Defaults to 50 if unspecified; values above 1000 are coerced to 1000.

        page_token : typing.Optional[str]
            A page token received from a previous `ListBatches` call, used to
            retrieve the subsequent page.

        order_by : typing.Optional[str]
            Ordering key: `created_at` for creation time, `updated_at` for last
            updated time.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListBatchesResponse
            A successful response.
        """
        return self._raw_client.list(
            page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options
        ).data

    def create(self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None) -> CreateBatchResponse:
        """
        Creates and executes a batch from an uploaded dataset of requests.

        Parameters
        ----------
        request : Batch

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateBatchResponse
            A successful response.
        """
        return self._raw_client.create(request=request, request_options=request_options).data

    def retrieve(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetBatchResponse:
        """
        Retrieves a batch.

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetBatchResponse
            A successful response.
        """
        return self._raw_client.retrieve(id, request_options=request_options).data

    def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> CancelBatchResponse:
        """
        Cancels an in-progress batch.

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CancelBatchResponse
            A successful response.
        """
        return self._raw_client.cancel(id, request_options=request_options).data
class AsyncBatchesClient:
    """
    High-level asynchronous client for the v2 batch endpoints.

    A thin wrapper over :class:`AsyncRawBatchesClient` that awaits each call
    and returns the parsed payload instead of the full HTTP response.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawBatchesClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawBatchesClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawBatchesClient
        """
        return self._raw_client

    async def list(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListBatchesResponse:
        """
        List the batches for the current user.

        Parameters
        ----------
        page_size : typing.Optional[int]
            The maximum number of batches to return; the service may return fewer.
            Defaults to 50 if unspecified; values above 1000 are coerced to 1000.

        page_token : typing.Optional[str]
            A page token received from a previous `ListBatches` call, used to
            retrieve the subsequent page.

        order_by : typing.Optional[str]
            Ordering key: `created_at` for creation time, `updated_at` for last
            updated time.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListBatchesResponse
            A successful response.
        """
        raw = await self._raw_client.list(
            page_size=page_size, page_token=page_token, order_by=order_by, request_options=request_options
        )
        return raw.data

    async def create(
        self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None
    ) -> CreateBatchResponse:
        """
        Creates and executes a batch from an uploaded dataset of requests.

        Parameters
        ----------
        request : Batch

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateBatchResponse
            A successful response.
        """
        raw = await self._raw_client.create(request=request, request_options=request_options)
        return raw.data

    async def retrieve(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetBatchResponse:
        """
        Retrieves a batch.

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetBatchResponse
            A successful response.
        """
        raw = await self._raw_client.retrieve(id, request_options=request_options)
        return raw.data

    async def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> CancelBatchResponse:
        """
        Cancels an in-progress batch.

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CancelBatchResponse
            A successful response.
        """
        raw = await self._raw_client.cancel(id, request_options=request_options)
        return raw.data
================================================
FILE: src/cohere/batches/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from json.decoder import JSONDecodeError
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.http_response import AsyncHttpResponse, HttpResponse
from ..core.jsonable_encoder import jsonable_encoder
from ..core.parse_error import ParsingError
from ..core.request_options import RequestOptions
from ..core.serialization import convert_and_respect_annotation_metadata
from ..core.unchecked_base_model import construct_type
from ..errors.bad_request_error import BadRequestError
from ..errors.forbidden_error import ForbiddenError
from ..errors.internal_server_error import InternalServerError
from ..errors.not_found_error import NotFoundError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.unauthorized_error import UnauthorizedError
from .types.batch import Batch
from .types.cancel_batch_response import CancelBatchResponse
from .types.create_batch_response import CreateBatchResponse
from .types.get_batch_response import GetBatchResponse
from .types.list_batches_response import ListBatchesResponse
from pydantic import ValidationError
# Sentinel (Ellipsis cast to Any) used as the default value for optional
# parameters; forwarded as ``omit=OMIT`` to the HTTP layer so that omitted
# arguments can be filtered out before serialization.
OMIT = typing.cast(typing.Any, ...)
class RawBatchesClient:
    """
    Low-level synchronous client for the v2 batch endpoints.

    Each method issues the HTTP request and returns the full ``HttpResponse``;
    non-2xx statuses are translated into typed exceptions by
    :meth:`_handle_response`.
    """

    # Known error status codes and the exception type raised for each.
    _ERROR_TYPES: typing.Dict[int, typing.Any] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        500: InternalServerError,
        503: ServiceUnavailableError,
    }

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    def _handle_response(self, _response, success_type):
        """
        Shared response post-processing.

        On 2xx, parse the JSON body into ``success_type`` and wrap it in an
        ``HttpResponse``. On a known error status, raise the corresponding typed
        exception with the parsed body. Otherwise raise a generic ``ApiError``.
        A body that fails JSON decoding yields ``ApiError`` with the raw text;
        a pydantic ``ValidationError`` during parsing yields ``ParsingError``.
        """
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    typing.Any,
                    construct_type(
                        type_=success_type,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            error_cls = self._ERROR_TYPES.get(_response.status_code)
            if error_cls is not None:
                raise error_cls(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def list(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[ListBatchesResponse]:
        """
        List the batches for the current user.

        Parameters
        ----------
        page_size : typing.Optional[int]
            The maximum number of batches to return; the service may return fewer.
            Defaults to 50 if unspecified; values above 1000 are coerced to 1000.

        page_token : typing.Optional[str]
            A page token received from a previous `ListBatches` call, used to
            retrieve the subsequent page.

        order_by : typing.Optional[str]
            Ordering key: `created_at` for creation time, `updated_at` for last
            updated time.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ListBatchesResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v2/batches",
            method="GET",
            params={
                "page_size": page_size,
                "page_token": page_token,
                "order_by": order_by,
            },
            request_options=request_options,
        )
        return self._handle_response(_response, ListBatchesResponse)

    def create(
        self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[CreateBatchResponse]:
        """
        Creates and executes a batch from an uploaded dataset of requests.

        Parameters
        ----------
        request : Batch

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[CreateBatchResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v2/batches",
            method="POST",
            json=convert_and_respect_annotation_metadata(object_=request, annotation=Batch, direction="write"),
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        return self._handle_response(_response, CreateBatchResponse)

    def retrieve(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[GetBatchResponse]:
        """
        Retrieves a batch.

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[GetBatchResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v2/batches/{jsonable_encoder(id)}",
            method="GET",
            request_options=request_options,
        )
        return self._handle_response(_response, GetBatchResponse)

    def cancel(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[CancelBatchResponse]:
        """
        Cancels an in-progress batch.

        Parameters
        ----------
        id : str
            The batch ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[CancelBatchResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v2/batches/{jsonable_encoder(id)}:cancel",
            method="POST",
            request_options=request_options,
        )
        return self._handle_response(_response, CancelBatchResponse)
class AsyncRawBatchesClient:
def __init__(self, *, client_wrapper: AsyncClientWrapper):
self._client_wrapper = client_wrapper
async def list(
self,
*,
page_size: typing.Optional[int] = None,
page_token: typing.Optional[str] = None,
order_by: typing.Optional[str] = None,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[ListBatchesResponse]:
"""
List the batches for the current user
Parameters
----------
page_size : typing.Optional[int]
The maximum number of batches to return. The service may return fewer than
this value.
If unspecified, at most 50 batches will be returned.
The maximum value is 1000; values above 1000 will be coerced to 1000.
page_token : typing.Optional[str]
A page token, received from a previous `ListBatches` call.
Provide this to retrieve the subsequent page.
order_by : typing.Optional[str]
Batches can be ordered by creation time or last updated time.
Use `created_at` for creation time or `updated_at` for last updated time.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[ListBatchesResponse]
A successful response.
"""
_response = await self._client_wrapper.httpx_client.request(
"v2/batches",
method="GET",
params={
"page_size": page_size,
"page_token": page_token,
"order_by": order_by,
},
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
ListBatchesResponse,
construct_type(
type_=ListBatchesResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def create(
self, *, request: Batch, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[CreateBatchResponse]:
"""
Creates and executes a batch from an uploaded dataset of requests
Parameters
----------
request : Batch
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[CreateBatchResponse]
A successful response.
"""
_response = await self._client_wrapper.httpx_client.request(
"v2/batches",
method="POST",
json=convert_and_respect_annotation_metadata(object_=request, annotation=Batch, direction="write"),
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
CreateBatchResponse,
construct_type(
type_=CreateBatchResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def retrieve(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[GetBatchResponse]:
    """
    Retrieves a batch

    Parameters
    ----------
    id : str
        The batch ID.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[GetBatchResponse]
        A successful response.

    Raises
    ------
    BadRequestError, UnauthorizedError, ForbiddenError, NotFoundError,
    InternalServerError, ServiceUnavailableError
        For the corresponding documented HTTP status codes (400, 401, 403,
        404, 500, 503); each carries the decoded JSON error body.
    ApiError
        For any other non-2xx status, or when the body is not valid JSON.
    ParsingError
        When the response body fails model validation.
    """
    # The batch ID is URL-encoded into the path via jsonable_encoder.
    _response = await self._client_wrapper.httpx_client.request(
        f"v2/batches/{jsonable_encoder(id)}",
        method="GET",
        request_options=request_options,
    )
    try:
        # 2xx: parse leniently into the typed response model (unknown
        # fields are tolerated by construct_type).
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                GetBatchResponse,
                construct_type(
                    type_=GetBatchResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Documented error statuses map 1:1 to typed exceptions below.
        if _response.status_code == 400:
            raise BadRequestError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 401:
            raise UnauthorizedError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 403:
            raise ForbiddenError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 404:
            raise NotFoundError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 500:
            raise InternalServerError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 503:
            raise ServiceUnavailableError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Undocumented status: decode the body here so the JSONDecodeError
        # handler below can fall back to raw text if it is not JSON.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def cancel(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[CancelBatchResponse]:
    """
    Cancels an in-progress batch

    Parameters
    ----------
    id : str
        The batch ID.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[CancelBatchResponse]
        A successful response.

    Raises
    ------
    BadRequestError, UnauthorizedError, ForbiddenError, NotFoundError,
    InternalServerError, ServiceUnavailableError
        For the corresponding documented HTTP status codes (400, 401, 403,
        404, 500, 503); each carries the decoded JSON error body.
    ApiError
        For any other non-2xx status, or when the body is not valid JSON.
    ParsingError
        When the response body fails model validation.
    """
    # Cancellation uses the custom-method URL form "<resource>:cancel".
    _response = await self._client_wrapper.httpx_client.request(
        f"v2/batches/{jsonable_encoder(id)}:cancel",
        method="POST",
        request_options=request_options,
    )
    try:
        # 2xx: parse leniently into the typed response model.
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                CancelBatchResponse,
                construct_type(
                    type_=CancelBatchResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Documented error statuses map 1:1 to typed exceptions below.
        if _response.status_code == 400:
            raise BadRequestError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 401:
            raise UnauthorizedError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 403:
            raise ForbiddenError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 404:
            raise NotFoundError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 500:
            raise InternalServerError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        if _response.status_code == 503:
            raise ServiceUnavailableError(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Undocumented status: decode here so the JSONDecodeError handler
        # below can fall back to raw text if the body is not JSON.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/batches/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .batch import Batch
from .batch_status import BatchStatus
from .cancel_batch_response import CancelBatchResponse
from .create_batch_response import CreateBatchResponse
from .get_batch_response import GetBatchResponse
from .list_batches_response import ListBatchesResponse
# Maps each public attribute name to the relative module that defines it.
# Consumed by the module-level __getattr__ below to import submodules
# lazily on first attribute access (PEP 562).
_dynamic_imports: typing.Dict[str, str] = {
    "Batch": ".batch",
    "BatchStatus": ".batch_status",
    "CancelBatchResponse": ".cancel_batch_response",
    "CreateBatchResponse": ".create_batch_response",
    "GetBatchResponse": ".get_batch_response",
    "ListBatchesResponse": ".list_batches_response",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve a public name via ``_dynamic_imports`` (PEP 562)."""
    target = _dynamic_imports.get(attr_name)
    if target is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        loaded = import_module(target, __package__)
        # When the mapping is "name" -> ".name", the attribute IS the
        # submodule itself; otherwise pull the attribute off the module.
        return loaded if target == f".{attr_name}" else getattr(loaded, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {target}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {target}: {e}") from e
def __dir__():
    """Advertise the lazily-importable names for dir() and completion."""
    return sorted(_dynamic_imports)
# Explicit public API; mirrors the keys of ``_dynamic_imports`` above.
__all__ = [
    "Batch",
    "BatchStatus",
    "CancelBatchResponse",
    "CreateBatchResponse",
    "GetBatchResponse",
    "ListBatchesResponse",
]
================================================
FILE: src/cohere/batches/types/batch.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch_status import BatchStatus
class Batch(UncheckedBaseModel):
    """
    This resource represents a batch job.

    Fields labeled "read-only" are populated by the server and are not
    meaningful on write. Unknown fields returned by the API are kept
    (``extra="allow"``) rather than rejected.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. Batch ID.
    """
    name: str = pydantic.Field()
    """
    Batch name (e.g. `foobar`).
    """
    creator_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. User ID of the creator.
    """
    org_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. Organization ID.
    """
    status: typing.Optional[BatchStatus] = pydantic.Field(default=None)
    """
    read-only. Current stage in the life-cycle of the batch.
    """
    created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Creation timestamp.
    """
    updated_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Latest update timestamp.
    """
    input_dataset_id: str = pydantic.Field()
    """
    ID of the dataset the batch reads inputs from.
    """
    output_dataset_id: typing.Optional[str] = None
    # NOTE(review): the token counters below are typed as strings —
    # presumably int64-serialized-as-string; confirm against the API schema.
    input_tokens: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. The total number of input tokens in the batch.
    """
    output_tokens: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. The total number of output tokens in the batch.
    """
    model: str = pydantic.Field()
    """
    The name of the model the batch uses.
    """
    num_records: typing.Optional[int] = pydantic.Field(default=None)
    """
    read-only. The total number of records in the batch.
    """
    num_successful_records: typing.Optional[int] = pydantic.Field(default=None)
    """
    read-only. The current number of successful records in the batch.
    """
    num_failed_records: typing.Optional[int] = pydantic.Field(default=None)
    """
    read-only. The current number of failed records in the batch.
    """
    status_reason: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. More details about the reason for the status of a batch job.
    """

    if IS_PYDANTIC_V2:
        # Keep unknown fields so newer server responses don't fail validation.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/batches/types/batch_status.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Life-cycle states of a batch job. The union with ``typing.Any`` keeps the
# type forward-compatible: values added by newer API versions still validate
# instead of failing.
BatchStatus = typing.Union[
    typing.Literal[
        "BATCH_STATUS_UNSPECIFIED",
        "BATCH_STATUS_QUEUED",
        "BATCH_STATUS_IN_PROGRESS",
        "BATCH_STATUS_CANCELING",
        "BATCH_STATUS_COMPLETED",
        "BATCH_STATUS_FAILED",
        "BATCH_STATUS_CANCELED",
    ],
    typing.Any,
]
================================================
FILE: src/cohere/batches/types/cancel_batch_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Modeled as a free-form JSON object rather than a typed pydantic model.
CancelBatchResponse = typing.Dict[str, typing.Any]
"""
Response to a request to cancel a batch.
"""
================================================
FILE: src/cohere/batches/types/create_batch_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch import Batch
class CreateBatchResponse(UncheckedBaseModel):
    """
    Response to request to create a batch.
    """

    batch: Batch = pydantic.Field()
    """
    Information about the batch.
    """

    if IS_PYDANTIC_V2:
        # Keep unknown fields so newer server responses don't fail validation.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/batches/types/get_batch_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch import Batch
class GetBatchResponse(UncheckedBaseModel):
    """
    Response to a request to get a batch.
    """

    batch: Batch = pydantic.Field()
    """
    Information about the batch.
    """

    if IS_PYDANTIC_V2:
        # Keep unknown fields so newer server responses don't fail validation.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/batches/types/list_batches_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from .batch import Batch
class ListBatchesResponse(UncheckedBaseModel):
    """
    Response to a request to list batches.
    """

    batches: typing.Optional[typing.List[Batch]] = pydantic.Field(default=None)
    """
    The batches that belong to the authenticated user.
    """
    next_page_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    A token, which can be sent as `page_token` to retrieve the next page.
    If this field is omitted, there are no subsequent pages.
    """

    if IS_PYDANTIC_V2:
        # Keep unknown fields so newer server responses don't fail validation.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/bedrock_client.py
================================================
import typing
from tokenizers import Tokenizer # type: ignore
from .aws_client import AwsClient, AwsClientV2
class BedrockClient(AwsClient):
    """Cohere v1 client backed by Amazon Bedrock.

    Thin wrapper that pins the AWS ``service`` to ``"bedrock"`` and forwards
    credentials, region, and timeout settings to ``AwsClient``.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        super().__init__(
            service="bedrock",
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
            timeout=timeout,
        )

    def rerank(self, *, query, documents, model = ..., top_n = ..., rank_fields = ..., return_documents = ..., max_chunks_per_doc = ..., request_options = None):
        # Rerank is not available through the v1 Bedrock client; fail fast
        # with guidance instead of erroring deep inside a request.
        raise NotImplementedError("Please use cohere.BedrockClientV2 instead: Rerank API on Bedrock is not supported with cohere.BedrockClient for this model.")
class BedrockClientV2(AwsClientV2):
    """Cohere v2 client backed by Amazon Bedrock.

    Thin wrapper that pins the AWS ``service`` to ``"bedrock"`` and forwards
    credentials, region, and timeout settings to ``AwsClientV2``.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        super().__init__(
            service="bedrock",
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
            timeout=timeout,
        )
================================================
FILE: src/cohere/client.py
================================================
import asyncio
import os
import typing
from concurrent.futures import ThreadPoolExecutor
from tokenizers import Tokenizer # type: ignore
import logging
import httpx
from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse
from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
from .config import embed_batch_size, embed_stream_batch_size
from .core import RequestOptions
from .environment import ClientEnvironment
from .manually_maintained.cache import CacheMixin
from .manually_maintained import tokenizers as local_tokenizers
from .overrides import run_overrides
from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils
logger = logging.getLogger(__name__)

# Executed once at import time; see .overrides for what it changes.
run_overrides()

# Use NoReturn as Never type for compatibility
Never = typing.NoReturn
def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None:
    """Replace ``obj.<method_name>`` with a wrapper that runs ``check_fn`` on
    the call arguments before delegating to the original method.

    Works for both sync and async methods; the wrapper keeps the original
    method's ``__name__`` and ``__doc__``.
    """
    original = getattr(obj, method_name)

    def _sync(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return original(*args, **kwargs)

    async def _async(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return await original(*args, **kwargs)

    replacement = _async if asyncio.iscoroutinefunction(original) else _sync
    replacement.__name__ = original.__name__
    replacement.__doc__ = original.__doc__
    setattr(obj, method_name, replacement)
def throw_if_stream_is_true(*args, **kwargs) -> None:
    """Reject the removed ``stream=True`` argument to ``chat``.

    Raises
    ------
    ValueError
        If the caller passed ``stream=True``.
    """
    stream_flag = kwargs.get("stream")
    if stream_flag is True:
        raise ValueError(
            "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
        )
def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
    """Build a stub that raises a ValueError pointing callers from the old
    location ``fn_name`` to the new location ``new_fn_name``.
    """

    def fn(*args, **kwargs):
        message = (
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been moved to {new_fn_name}(...). "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )
        raise ValueError(message)

    return fn
def deprecated_function(fn_name: str) -> typing.Any:
    """Build a stub that raises a ValueError announcing that ``fn_name``
    was deprecated in cohere==5.0.0.
    """

    def fn(*args, **kwargs):
        message = (
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been deprecated. "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )
        raise ValueError(message)

    return fn
# Logs a warning when a user calls a function with an experimental parameter (kwarg in our case)
# `deprecated_kwarg` is the name of the experimental parameter, which can be a dot-separated string for nested parameters
def experimental_kwarg_decorator(func, deprecated_kwarg):
    """Wrap ``func`` so that passing the (possibly dot-nested)
    ``deprecated_kwarg`` logs an experimental-feature warning before the
    call proceeds. Handles both sync and async ``func``.
    """

    def _present(path: str, mapping: typing.Dict[str, typing.Any]) -> bool:
        # Walk one dotted segment at a time; when the head is absent we fall
        # back to a plain lookup of the full path string.
        if "." in path:
            head, tail = path.split(".", 1)
            if head in mapping:
                return _present(tail, mapping[head])
        return path in mapping

    def _warn_if_used(kwargs):
        if _present(deprecated_kwarg, kwargs):
            logger.warning(
                f"The `{deprecated_kwarg}` parameter is an experimental feature and may change in future releases.\n"
                "To suppress this warning, set `log_warning_experimental_features=False` when initializing the client."
            )

    def _sync(*args, **kwargs):
        _warn_if_used(kwargs)
        return func(*args, **kwargs)

    async def _async(*args, **kwargs):
        _warn_if_used(kwargs)
        return await func(*args, **kwargs)

    chosen = _async if asyncio.iscoroutinefunction(func) else _sync
    chosen.__name__ = func.__name__
    chosen.__doc__ = func.__doc__
    return chosen
def fix_base_url(base_url: typing.Optional[str]) -> typing.Optional[str]:
    """Strip the legacy "/v1" path segment from Cohere-hosted base URLs.

    ``None`` and non-Cohere URLs are returned unchanged.
    """
    if base_url is None:
        return None
    is_cohere_host = "cohere.com" in base_url or "cohere.ai" in base_url
    return base_url.replace("/v1", "") if is_cohere_host else base_url
class Client(BaseCohere, CacheMixin):
    """Synchronous Cohere v1 client.

    Extends the generated ``BaseCohere`` with:
    - API-key lookup from the environment when no key is given,
    - transparent client-side batching of ``embed`` (responses merged back
      into one),
    - a memory-efficient ``embed_stream`` helper,
    - offline (local) tokenization with API fallback,
    - stubs that raise helpful errors for APIs moved/removed in cohere==5.0.0.
    """

    # Thread pool used to parallelize batched embed calls.
    _executor: ThreadPoolExecutor

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.Client] = None,
        thread_pool_executor: typing.Optional[ThreadPoolExecutor] = None,
        log_warning_experimental_features: bool = True,
    ):
        if api_key is None:
            api_key = _get_api_key_from_environment()
        base_url = fix_base_url(base_url)
        # Bug fix: the executor default used to be `ThreadPoolExecutor(64)`
        # evaluated once in the signature at import time, so every Client
        # constructed without an explicit executor silently shared a single
        # pool (mutable-default-argument pitfall) that was created even when
        # never used. Build a per-instance pool instead; passing an executor
        # explicitly behaves exactly as before.
        self._executor = thread_pool_executor if thread_pool_executor is not None else ThreadPoolExecutor(64)
        BaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )
        # chat(stream=True) was removed in 5.0.0; fail fast with guidance.
        validate_args(self, "chat", throw_if_stream_is_true)
        if log_warning_experimental_features:
            self.chat = experimental_kwarg_decorator(self.chat, "response_format.schema")  # type: ignore
            self.chat_stream = experimental_kwarg_decorator(self.chat_stream, "response_format.schema")  # type: ignore

    utils = SyncSdkUtils()

    # support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._client_wrapper.httpx_client.httpx_client.close()

    wait = wait

    def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        """Embed texts (or images), transparently splitting large text inputs
        into batches executed in parallel on ``self._executor`` and merging
        the per-batch responses into a single ``EmbedResponse``.
        """
        # skip batching for images for now
        if batching is False or images is not OMIT:
            return BaseCohere.embed(
                self,
                texts=texts,
                images=images,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )
        textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]
        responses = [
            response
            for response in self._executor.map(
                lambda text_batch: BaseCohere.embed(
                    self,
                    texts=text_batch,
                    model=model,
                    input_type=input_type,
                    embedding_types=embedding_types,
                    truncate=truncate,
                    request_options=request_options,
                ),
                texts_batches,
            )
        ]
        return merge_embed_responses(responses)

    def embed_stream(
        self,
        *,
        texts: typing.Sequence[str],
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        batch_size: int = embed_stream_batch_size,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.Iterator[typing.Any]:
        """
        Memory-efficient embed that yields embeddings one batch at a time.

        Processes texts in batches and yields individual StreamedEmbedding objects
        as they come back, so you can write to a vector store incrementally without
        holding all embeddings in memory.

        Args:
            texts: Texts to embed.
            model: Embedding model ID.
            input_type: Input type (search_document, search_query, etc.).
            embedding_types: Types of embeddings to return (float, int8, etc.).
            truncate: How to handle inputs longer than the max token length.
            batch_size: Texts per API call. Defaults to 96 (API max).
            request_options: Request-specific configuration.

        Yields:
            StreamedEmbedding with index, embedding, embedding_type, and text.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # Imported lazily so the helper module is only loaded when streaming.
        from .manually_maintained.streaming_embed import extract_embeddings_from_response

        if not texts:
            return
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        texts_list = list(texts)
        for batch_start in range(0, len(texts_list), batch_size):
            batch_texts = texts_list[batch_start : batch_start + batch_size]
            response = BaseCohere.embed(
                self,
                texts=batch_texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )
            response_data = response.dict() if hasattr(response, "dict") else response.__dict__
            yield from extract_embeddings_from_response(response_data, batch_texts, batch_start)

    """
    The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
    Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    """
    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")

    def tokenize(
        self,
        *,
        text: str,
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: bool = True,
    ) -> TokenizeResponse:
        # `offline` parameter controls whether to use an offline tokenizer. If set to True, the tokenizer config will be downloaded (and cached),
        # and the request will be processed using the offline tokenizer. If set to False, the request will be processed using the API. The default value is True.
        opts: RequestOptions = request_options or {}  # type: ignore
        if offline:
            try:
                tokens = local_tokenizers.local_tokenize(self, text=text, model=model)
                return TokenizeResponse(tokens=tokens, token_strings=[])
            except Exception:
                # Fallback to calling the API.
                # NOTE(review): this mutates the caller's request_options /
                # additional_headers dicts in the fallback path.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
        return super().tokenize(text=text, model=model, request_options=opts)

    def detokenize(
        self,
        *,
        tokens: typing.Sequence[int],
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> DetokenizeResponse:
        # `offline` parameter controls whether to use an offline tokenizer. If set to True, the tokenizer config will be downloaded (and cached),
        # and the request will be processed using the offline tokenizer. If set to False, the request will be processed using the API. The default value is True.
        opts: RequestOptions = request_options or {}  # type: ignore
        if offline:
            try:
                text = local_tokenizers.local_detokenize(self, model=model, tokens=tokens)
                return DetokenizeResponse(text=text)
            except Exception:
                # Fallback to calling the API.
                # NOTE(review): mutates the caller's request_options dicts,
                # mirroring tokenize() above.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
        return super().detokenize(tokens=tokens, model=model, request_options=opts)

    def fetch_tokenizer(self, *, model: str) -> Tokenizer:
        """
        Returns a Hugging Face tokenizer from a given model name.
        """
        return local_tokenizers.get_hf_tokenizer(self, model)
class AsyncClient(AsyncBaseCohere, CacheMixin):
_executor: ThreadPoolExecutor
def __init__(
self,
api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
*,
base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
client_name: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
log_warning_experimental_features: bool = True,
):
if api_key is None:
api_key = _get_api_key_from_environment()
base_url = fix_base_url(base_url)
self._executor = thread_pool_executor
AsyncBaseCohere.__init__(
self,
base_url=base_url,
environment=environment,
client_name=client_name,
token=api_key,
timeout=timeout,
httpx_client=httpx_client,
)
validate_args(self, "chat", throw_if_stream_is_true)
if log_warning_experimental_features:
self.chat = experimental_kwarg_decorator(self.chat, "response_format.schema") # type: ignore
self.chat_stream = experimental_kwarg_decorator(self.chat_stream, "response_format.schema") # type: ignore
utils = AsyncSdkUtils()
# support context manager until Fern upstreams
# https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self._client_wrapper.httpx_client.httpx_client.aclose()
wait = async_wait
async def embed(
self,
*,
texts: typing.Optional[typing.Sequence[str]] = OMIT,
images: typing.Optional[typing.Sequence[str]] = OMIT,
model: typing.Optional[str] = OMIT,
input_type: typing.Optional[EmbedInputType] = OMIT,
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
batching: typing.Optional[bool] = True,
) -> EmbedResponse:
# skip batching for images for now
if batching is False or images is not OMIT:
return await AsyncBaseCohere.embed(
self,
texts=texts,
images=images,
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
)
textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]
responses = typing.cast(
typing.List[EmbedResponse],
await asyncio.gather(
*[
AsyncBaseCohere.embed(
self,
texts=text_batch,
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
)
for text_batch in texts_batches
]
),
)
return merge_embed_responses(responses)
"""
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
"""
check_api_key: Never = deprecated_function("check_api_key")
loglikelihood: Never = deprecated_function("loglikelihood")
batch_generate: Never = deprecated_function("batch_generate")
codebook: Never = deprecated_function("codebook")
batch_tokenize: Never = deprecated_function("batch_tokenize")
batch_detokenize: Never = deprecated_function("batch_detokenize")
detect_language: Never = deprecated_function("detect_language")
generate_feedback: Never = deprecated_function("generate_feedback")
generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
create_dataset: Never = moved_function("create_dataset", ".datasets.create")
get_dataset: Never = moved_function("get_dataset", ".datasets.get")
list_datasets: Never = moved_function("list_datasets", ".datasets.list")
delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
_check_response: Never = deprecated_function("_check_response")
_request: Never = deprecated_function("_request")
create_cluster_job: Never = deprecated_function("create_cluster_job")
get_cluster_job: Never = deprecated_function("get_cluster_job")
list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
create_custom_model: Never = deprecated_function("create_custom_model")
wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
_upload_dataset: Never = deprecated_function("_upload_dataset")
_create_signed_url: Never = deprecated_function("_create_signed_url")
get_custom_model: Never = deprecated_function("get_custom_model")
get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
list_custom_models: Never = deprecated_function("list_custom_models")
create_connector: Never = moved_function("create_connector", ".connectors.create")
update_connector: Never = moved_function("update_connector", ".connectors.update")
get_connector: Never = moved_function("get_connector", ".connectors.get")
list_connectors: Never = moved_function("list_connectors", ".connectors.list")
delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")
async def tokenize(
    self,
    *,
    text: str,
    model: str,
    request_options: typing.Optional[RequestOptions] = None,
    offline: typing.Optional[bool] = True,
) -> TokenizeResponse:
    """
    Tokenize ``text`` using the tokenizer for ``model``.

    When ``offline`` is True (the default), the tokenizer config is
    downloaded (and cached) and tokenization runs locally; on any local
    failure the call transparently falls back to the API, tagging the
    request with a warning header so the fallback is observable
    server-side.  When ``offline`` is False, the API is used directly.

    Parameters
    ----------
    text : str
        The text to tokenize.
    model : str
        The model whose tokenizer should be used.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.  Never mutated by this method.
    offline : typing.Optional[bool]
        Prefer the local (offline) tokenizer over the API.

    Returns
    -------
    TokenizeResponse
    """
    # Copy the caller's options so the fallback path below never mutates
    # the dict (or its nested `additional_headers`) the caller handed us.
    opts: RequestOptions = dict(request_options) if request_options else {}  # type: ignore
    if offline:
        try:
            tokens = await local_tokenizers.async_local_tokenize(self, model=model, text=text)
            # The local path produces token ids only; token_strings come
            # exclusively from the API path.
            return TokenizeResponse(tokens=tokens, token_strings=[])
        except Exception:
            headers = dict(opts.get("additional_headers") or {})
            headers["sdk-api-warning-message"] = "offline_tokenizer_failed"
            opts["additional_headers"] = headers
    return await super().tokenize(text=text, model=model, request_options=opts)
async def detokenize(
    self,
    *,
    tokens: typing.Sequence[int],
    model: str,
    request_options: typing.Optional[RequestOptions] = None,
    offline: typing.Optional[bool] = True,
) -> DetokenizeResponse:
    """
    Convert ``tokens`` back into text using the tokenizer for ``model``.

    When ``offline`` is True (the default), the tokenizer config is
    downloaded (and cached) and detokenization runs locally; on any local
    failure the call transparently falls back to the API, tagging the
    request with a warning header so the fallback is observable
    server-side.  When ``offline`` is False, the API is used directly.

    Parameters
    ----------
    tokens : typing.Sequence[int]
        The token ids to detokenize.
    model : str
        The model whose tokenizer should be used.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.  Never mutated by this method.
    offline : typing.Optional[bool]
        Prefer the local (offline) tokenizer over the API.

    Returns
    -------
    DetokenizeResponse
    """
    # Copy the caller's options so the fallback path below never mutates
    # the dict (or its nested `additional_headers`) the caller handed us.
    opts: RequestOptions = dict(request_options) if request_options else {}  # type: ignore
    if offline:
        try:
            text = await local_tokenizers.async_local_detokenize(self, model=model, tokens=tokens)
            return DetokenizeResponse(text=text)
        except Exception:
            headers = dict(opts.get("additional_headers") or {})
            headers["sdk-api-warning-message"] = "offline_tokenizer_failed"
            opts["additional_headers"] = headers
    return await super().detokenize(tokens=tokens, model=model, request_options=opts)
async def fetch_tokenizer(self, *, model: str) -> Tokenizer:
    """
    Returns a Hugging Face tokenizer from a given model name.

    Delegates to the local-tokenizer helpers, which fetch the tokenizer
    config via this client.  NOTE(review): caching behaviour lives in
    `local_tokenizers.async_get_hf_tokenizer` — confirm there.
    """
    return await local_tokenizers.async_get_hf_tokenizer(self, model)
def _get_api_key_from_environment() -> typing.Optional[str]:
    """
    Look up the Cohere API key in the process environment.

    ``CO_API_KEY`` is the preferred (documented) variable; ``COHERE_API_KEY``
    is accepted as an undocumented fallback.  Returns ``None`` when neither
    is set.
    """
    for variable in ("CO_API_KEY", "COHERE_API_KEY"):
        value = os.environ.get(variable)
        if value is not None:
            return value
    return None
================================================
FILE: src/cohere/client_v2.py
================================================
import os
import typing
from concurrent.futures import ThreadPoolExecutor
import httpx
from .client import AsyncClient, Client
from .environment import ClientEnvironment
from .v2.client import AsyncRawV2Client, AsyncV2Client, RawV2Client, V2Client
class _CombinedRawClient:
    """Attribute proxy merging a v2 raw client with a legacy v1 raw client.

    ``V2Client`` and ``Client`` both assign ``self._raw_client`` in their
    ``__init__``, which collides when the two are combined in
    ``ClientV2`` / ``AsyncClientV2``.  Lookups on this proxy prefer the
    v2 client and fall back to the v1 client for legacy-only methods
    (e.g. ``generate_stream``).
    """

    def __init__(self, v1_raw_client: typing.Any, v2_raw_client: typing.Any):
        self._v1 = v1_raw_client
        self._v2 = v2_raw_client

    def __getattr__(self, name: str) -> typing.Any:
        # Prefer v2; a per-call sentinel distinguishes "missing" from any
        # real attribute value (including None).
        missing = object()
        value = getattr(self._v2, name, missing)
        if value is not missing:
            return value
        # Let the v1 lookup raise AttributeError itself when absent there too.
        return getattr(self._v1, name)
class ClientV2(V2Client, Client):  # type: ignore
    """Synchronous client exposing the v2 API surface plus legacy v1 methods."""

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        # NOTE(review): these defaults are evaluated once at import time
        # (os.getenv result and a shared ThreadPoolExecutor) — confirm this
        # mirrors the generated parent `Client` signature before changing.
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.Client] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        # Initialise the v1 client first; it builds self._client_wrapper and
        # assigns self._raw_client.
        Client.__init__(
            self,
            api_key=api_key,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            timeout=timeout,
            httpx_client=httpx_client,
            thread_pool_executor=thread_pool_executor,
            log_warning_experimental_features=log_warning_experimental_features,
        )
        # Capture the v1 raw client BEFORE V2Client.__init__ overwrites
        # self._raw_client with the v2 one — order matters here.
        v1_raw = self._raw_client
        V2Client.__init__(
            self,
            client_wrapper=self._client_wrapper
        )
        # Combine both raw clients: v2 takes priority, v1 serves legacy
        # methods.  Cast keeps the declared RawV2Client-facing type.
        self._raw_client = typing.cast(RawV2Client, _CombinedRawClient(v1_raw, self._raw_client))
class AsyncClientV2(AsyncV2Client, AsyncClient):  # type: ignore
    """Asynchronous client exposing the v2 API surface plus legacy v1 methods."""

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        # NOTE(review): these defaults are evaluated once at import time
        # (os.getenv result and a shared ThreadPoolExecutor) — confirm this
        # mirrors the generated parent `AsyncClient` signature before changing.
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        # Initialise the v1 async client first; it builds
        # self._client_wrapper and assigns self._raw_client.
        AsyncClient.__init__(
            self,
            api_key=api_key,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            timeout=timeout,
            httpx_client=httpx_client,
            thread_pool_executor=thread_pool_executor,
            log_warning_experimental_features=log_warning_experimental_features,
        )
        # Capture the v1 raw client BEFORE AsyncV2Client.__init__ overwrites
        # self._raw_client with the v2 one — order matters here.
        v1_raw = self._raw_client
        AsyncV2Client.__init__(
            self,
            client_wrapper=self._client_wrapper
        )
        # Combine both raw clients: v2 takes priority, v1 serves legacy
        # methods.  Cast keeps the declared AsyncRawV2Client-facing type.
        self._raw_client = typing.cast(AsyncRawV2Client, _CombinedRawClient(v1_raw, self._raw_client))
================================================
FILE: src/cohere/config.py
================================================
# Texts per batch when chunking embed requests.
# NOTE(review): presumably consumed by the embed helpers in client.py — verify.
embed_batch_size = 96
embed_stream_batch_size = 96  # Max texts per API request (API limit)
================================================
FILE: src/cohere/connectors/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
================================================
FILE: src/cohere/connectors/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from ..types.create_connector_o_auth import CreateConnectorOAuth
from ..types.create_connector_response import CreateConnectorResponse
from ..types.create_connector_service_auth import CreateConnectorServiceAuth
from ..types.delete_connector_response import DeleteConnectorResponse
from ..types.get_connector_response import GetConnectorResponse
from ..types.list_connectors_response import ListConnectorsResponse
from ..types.o_auth_authorize_response import OAuthAuthorizeResponse
from ..types.update_connector_response import UpdateConnectorResponse
from .raw_client import AsyncRawConnectorsClient, RawConnectorsClient
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)
class ConnectorsClient:
    """Typed, high-level client for the ``/v1/connectors`` endpoints.

    Each method delegates to :class:`RawConnectorsClient` and unwraps the
    parsed ``.data`` payload from the HTTP response envelope.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawConnectorsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawConnectorsClient:
        """
        Access the raw-response variant of this client, whose methods return
        the full HTTP response envelope instead of just the parsed body.

        Returns
        -------
        RawConnectorsClient
        """
        return self._raw_client

    def list(
        self,
        *,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListConnectorsResponse:
        """
        List connectors, newest first.  See
        ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector)
        for more information.

        Parameters
        ----------
        limit : typing.Optional[float]
            Maximum number of connectors to return [0, 100].
        offset : typing.Optional[float]
            Number of connectors to skip before returning results [0, inf].
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListConnectorsResponse
            OK
        """
        return self._raw_client.list(
            limit=limit, offset=offset, request_options=request_options
        ).data

    def create(
        self,
        *,
        name: str,
        url: str,
        description: typing.Optional[str] = OMIT,
        excludes: typing.Optional[typing.Sequence[str]] = OMIT,
        oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
        active: typing.Optional[bool] = OMIT,
        continue_on_failure: typing.Optional[bool] = OMIT,
        service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> CreateConnectorResponse:
        """
        Register a new connector.  The connector is tested during
        registration, and registration is cancelled if the test fails.  See
        ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector)
        for more information.

        Parameters
        ----------
        name : str
            A human-readable name for the connector.
        url : str
            The URL of the connector that will be used to search for documents.
        description : typing.Optional[str]
            A description of the connector.
        excludes : typing.Optional[typing.Sequence[str]]
            A list of fields to exclude from the prompt (fields remain in the document).
        oauth : typing.Optional[CreateConnectorOAuth]
            The OAuth 2.0 configuration for the connector.  Cannot be
            specified together with service_auth.
        active : typing.Optional[bool]
            Whether the connector is active or not.
        continue_on_failure : typing.Optional[bool]
            Whether a chat request should continue if the request to this
            connector fails.
        service_auth : typing.Optional[CreateConnectorServiceAuth]
            The service-to-service authentication configuration for the
            connector.  Cannot be specified together with oauth.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateConnectorResponse
            OK
        """
        return self._raw_client.create(
            name=name,
            url=url,
            description=description,
            excludes=excludes,
            oauth=oauth,
            active=active,
            continue_on_failure=continue_on_failure,
            service_auth=service_auth,
            request_options=request_options,
        ).data

    def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetConnectorResponse:
        """
        Retrieve a connector by ID.  See
        ['Connectors'](https://docs.cohere.com/docs/connectors) for more
        information.

        Parameters
        ----------
        id : str
            The ID of the connector to retrieve.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetConnectorResponse
            OK
        """
        return self._raw_client.get(id, request_options=request_options).data

    def delete(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> DeleteConnectorResponse:
        """
        Delete a connector by ID.  See
        ['Connectors'](https://docs.cohere.com/docs/connectors) for more
        information.

        Parameters
        ----------
        id : str
            The ID of the connector to delete.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DeleteConnectorResponse
            OK
        """
        return self._raw_client.delete(id, request_options=request_options).data

    def update(
        self,
        id: str,
        *,
        name: typing.Optional[str] = OMIT,
        url: typing.Optional[str] = OMIT,
        excludes: typing.Optional[typing.Sequence[str]] = OMIT,
        oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
        active: typing.Optional[bool] = OMIT,
        continue_on_failure: typing.Optional[bool] = OMIT,
        service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> UpdateConnectorResponse:
        """
        Update a connector by ID; omitted fields are left unchanged.  See
        ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector)
        for more information.

        Parameters
        ----------
        id : str
            The ID of the connector to update.
        name : typing.Optional[str]
            A human-readable name for the connector.
        url : typing.Optional[str]
            The URL of the connector that will be used to search for documents.
        excludes : typing.Optional[typing.Sequence[str]]
            A list of fields to exclude from the prompt (fields remain in the document).
        oauth : typing.Optional[CreateConnectorOAuth]
            The OAuth 2.0 configuration for the connector.  Cannot be
            specified together with service_auth.
        active : typing.Optional[bool]
        continue_on_failure : typing.Optional[bool]
        service_auth : typing.Optional[CreateConnectorServiceAuth]
            The service-to-service authentication configuration for the
            connector.  Cannot be specified together with oauth.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        UpdateConnectorResponse
            OK
        """
        return self._raw_client.update(
            id,
            name=name,
            url=url,
            excludes=excludes,
            oauth=oauth,
            active=active,
            continue_on_failure=continue_on_failure,
            service_auth=service_auth,
            request_options=request_options,
        ).data

    def o_auth_authorize(
        self,
        id: str,
        *,
        after_token_redirect: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> OAuthAuthorizeResponse:
        """
        Authorize the connector with the given ID for the connector OAuth
        app.  See
        ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication)
        for more information.

        Parameters
        ----------
        id : str
            The ID of the connector to authorize.
        after_token_redirect : typing.Optional[str]
            The URL to redirect to after the connector has been authorized.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        OAuthAuthorizeResponse
            OK
        """
        return self._raw_client.o_auth_authorize(
            id, after_token_redirect=after_token_redirect, request_options=request_options
        ).data
class AsyncConnectorsClient:
    """Typed, high-level async client for the ``/v1/connectors`` endpoints.

    Each coroutine delegates to :class:`AsyncRawConnectorsClient` and
    unwraps the parsed ``.data`` payload from the HTTP response envelope.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawConnectorsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawConnectorsClient:
        """
        Access the raw-response variant of this client, whose methods return
        the full HTTP response envelope instead of just the parsed body.

        Returns
        -------
        AsyncRawConnectorsClient
        """
        return self._raw_client

    async def list(
        self,
        *,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListConnectorsResponse:
        """
        List connectors, newest first.  See
        ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector)
        for more information.

        Parameters
        ----------
        limit : typing.Optional[float]
            Maximum number of connectors to return [0, 100].
        offset : typing.Optional[float]
            Number of connectors to skip before returning results [0, inf].
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListConnectorsResponse
            OK
        """
        envelope = await self._raw_client.list(limit=limit, offset=offset, request_options=request_options)
        return envelope.data

    async def create(
        self,
        *,
        name: str,
        url: str,
        description: typing.Optional[str] = OMIT,
        excludes: typing.Optional[typing.Sequence[str]] = OMIT,
        oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
        active: typing.Optional[bool] = OMIT,
        continue_on_failure: typing.Optional[bool] = OMIT,
        service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> CreateConnectorResponse:
        """
        Register a new connector.  The connector is tested during
        registration, and registration is cancelled if the test fails.  See
        ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector)
        for more information.

        Parameters
        ----------
        name : str
            A human-readable name for the connector.
        url : str
            The URL of the connector that will be used to search for documents.
        description : typing.Optional[str]
            A description of the connector.
        excludes : typing.Optional[typing.Sequence[str]]
            A list of fields to exclude from the prompt (fields remain in the document).
        oauth : typing.Optional[CreateConnectorOAuth]
            The OAuth 2.0 configuration for the connector.  Cannot be
            specified together with service_auth.
        active : typing.Optional[bool]
            Whether the connector is active or not.
        continue_on_failure : typing.Optional[bool]
            Whether a chat request should continue if the request to this
            connector fails.
        service_auth : typing.Optional[CreateConnectorServiceAuth]
            The service-to-service authentication configuration for the
            connector.  Cannot be specified together with oauth.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateConnectorResponse
            OK
        """
        envelope = await self._raw_client.create(
            name=name,
            url=url,
            description=description,
            excludes=excludes,
            oauth=oauth,
            active=active,
            continue_on_failure=continue_on_failure,
            service_auth=service_auth,
            request_options=request_options,
        )
        return envelope.data

    async def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetConnectorResponse:
        """
        Retrieve a connector by ID.  See
        ['Connectors'](https://docs.cohere.com/docs/connectors) for more
        information.

        Parameters
        ----------
        id : str
            The ID of the connector to retrieve.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetConnectorResponse
            OK
        """
        envelope = await self._raw_client.get(id, request_options=request_options)
        return envelope.data

    async def delete(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> DeleteConnectorResponse:
        """
        Delete a connector by ID.  See
        ['Connectors'](https://docs.cohere.com/docs/connectors) for more
        information.

        Parameters
        ----------
        id : str
            The ID of the connector to delete.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DeleteConnectorResponse
            OK
        """
        envelope = await self._raw_client.delete(id, request_options=request_options)
        return envelope.data

    async def update(
        self,
        id: str,
        *,
        name: typing.Optional[str] = OMIT,
        url: typing.Optional[str] = OMIT,
        excludes: typing.Optional[typing.Sequence[str]] = OMIT,
        oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
        active: typing.Optional[bool] = OMIT,
        continue_on_failure: typing.Optional[bool] = OMIT,
        service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> UpdateConnectorResponse:
        """
        Update a connector by ID; omitted fields are left unchanged.  See
        ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector)
        for more information.

        Parameters
        ----------
        id : str
            The ID of the connector to update.
        name : typing.Optional[str]
            A human-readable name for the connector.
        url : typing.Optional[str]
            The URL of the connector that will be used to search for documents.
        excludes : typing.Optional[typing.Sequence[str]]
            A list of fields to exclude from the prompt (fields remain in the document).
        oauth : typing.Optional[CreateConnectorOAuth]
            The OAuth 2.0 configuration for the connector.  Cannot be
            specified together with service_auth.
        active : typing.Optional[bool]
        continue_on_failure : typing.Optional[bool]
        service_auth : typing.Optional[CreateConnectorServiceAuth]
            The service-to-service authentication configuration for the
            connector.  Cannot be specified together with oauth.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        UpdateConnectorResponse
            OK
        """
        envelope = await self._raw_client.update(
            id,
            name=name,
            url=url,
            excludes=excludes,
            oauth=oauth,
            active=active,
            continue_on_failure=continue_on_failure,
            service_auth=service_auth,
            request_options=request_options,
        )
        return envelope.data

    async def o_auth_authorize(
        self,
        id: str,
        *,
        after_token_redirect: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> OAuthAuthorizeResponse:
        """
        Authorize the connector with the given ID for the connector OAuth
        app.  See
        ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication)
        for more information.

        Parameters
        ----------
        id : str
            The ID of the connector to authorize.
        after_token_redirect : typing.Optional[str]
            The URL to redirect to after the connector has been authorized.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        OAuthAuthorizeResponse
            OK
        """
        envelope = await self._raw_client.o_auth_authorize(
            id, after_token_redirect=after_token_redirect, request_options=request_options
        )
        return envelope.data
================================================
FILE: src/cohere/connectors/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from json.decoder import JSONDecodeError
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.http_response import AsyncHttpResponse, HttpResponse
from ..core.jsonable_encoder import jsonable_encoder
from ..core.parse_error import ParsingError
from ..core.request_options import RequestOptions
from ..core.serialization import convert_and_respect_annotation_metadata
from ..core.unchecked_base_model import construct_type
from ..errors.bad_request_error import BadRequestError
from ..errors.client_closed_request_error import ClientClosedRequestError
from ..errors.forbidden_error import ForbiddenError
from ..errors.gateway_timeout_error import GatewayTimeoutError
from ..errors.internal_server_error import InternalServerError
from ..errors.invalid_token_error import InvalidTokenError
from ..errors.not_found_error import NotFoundError
from ..errors.not_implemented_error import NotImplementedError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.too_many_requests_error import TooManyRequestsError
from ..errors.unauthorized_error import UnauthorizedError
from ..errors.unprocessable_entity_error import UnprocessableEntityError
from ..types.create_connector_o_auth import CreateConnectorOAuth
from ..types.create_connector_response import CreateConnectorResponse
from ..types.create_connector_service_auth import CreateConnectorServiceAuth
from ..types.delete_connector_response import DeleteConnectorResponse
from ..types.get_connector_response import GetConnectorResponse
from ..types.list_connectors_response import ListConnectorsResponse
from ..types.o_auth_authorize_response import OAuthAuthorizeResponse
from ..types.update_connector_response import UpdateConnectorResponse
from pydantic import ValidationError
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)
class RawConnectorsClient:
def __init__(self, *, client_wrapper: SyncClientWrapper):
    # The wrapper supplies the configured httpx client (auth headers,
    # base URL, timeouts) used by every request method below.
    self._client_wrapper = client_wrapper
def list(
    self,
    *,
    limit: typing.Optional[float] = None,
    offset: typing.Optional[float] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[ListConnectorsResponse]:
    """
    Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.

    Parameters
    ----------
    limit : typing.Optional[float]
        Maximum number of connectors to return [0, 100].

    offset : typing.Optional[float]
        Number of connectors to skip before returning results [0, inf].

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[ListConnectorsResponse]
        OK
    """
    _response = self._client_wrapper.httpx_client.request(
        "v1/connectors",
        method="GET",
        params={
            "limit": limit,
            "offset": offset,
        },
        request_options=request_options,
    )
    # Every error branch of the original hand-unrolled cascade raised the
    # same shape (headers dict + JSON body), so a table-driven raise is
    # behaviourally equivalent and removes ~120 duplicated lines.
    _error_classes: typing.Dict[int, typing.Type[Exception]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                ListConnectorsResponse,
                construct_type(
                    type_=ListConnectorsResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return HttpResponse(response=_response, data=_data)
        _error_class = _error_classes.get(_response.status_code)
        if _error_class is not None:
            raise _error_class(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Non-JSON body: surface the raw text rather than a parse failure.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unmapped non-2xx status: raise a generic ApiError with the JSON body.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def create(
    self,
    *,
    name: str,
    url: str,
    description: typing.Optional[str] = OMIT,
    excludes: typing.Optional[typing.Sequence[str]] = OMIT,
    oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
    active: typing.Optional[bool] = OMIT,
    continue_on_failure: typing.Optional[bool] = OMIT,
    service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[CreateConnectorResponse]:
    """
    Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information.

    Parameters
    ----------
    name : str
        A human-readable name for the connector.

    url : str
        The URL of the connector that will be used to search for documents.

    description : typing.Optional[str]
        A description of the connector.

    excludes : typing.Optional[typing.Sequence[str]]
        A list of fields to exclude from the prompt (fields remain in the document).

    oauth : typing.Optional[CreateConnectorOAuth]
        The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.

    active : typing.Optional[bool]
        Whether the connector is active or not.

    continue_on_failure : typing.Optional[bool]
        Whether a chat request should continue or not if the request to this connector fails.

    service_auth : typing.Optional[CreateConnectorServiceAuth]
        The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[CreateConnectorResponse]
        OK
    """
    _response = self._client_wrapper.httpx_client.request(
        "v1/connectors",
        method="POST",
        json={
            "name": name,
            "description": description,
            "url": url,
            "excludes": excludes,
            "oauth": convert_and_respect_annotation_metadata(
                object_=oauth, annotation=CreateConnectorOAuth, direction="write"
            ),
            "active": active,
            "continue_on_failure": continue_on_failure,
            "service_auth": convert_and_respect_annotation_metadata(
                object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write"
            ),
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    try:
        if 200 <= _response.status_code < 300:
            return HttpResponse(
                response=_response,
                data=typing.cast(
                    CreateConnectorResponse,
                    construct_type(
                        type_=CreateConnectorResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Each documented error status maps onto a typed exception; the raised
        # error carries the response headers and the parsed JSON payload.
        _status_to_error = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code with a JSON body: surface it as a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def get(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> HttpResponse[GetConnectorResponse]:
    """
    Retrieve a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.

    Parameters
    ----------
    id : str
        The ID of the connector to retrieve.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[GetConnectorResponse]
        OK
    """
    _response = self._client_wrapper.httpx_client.request(
        f"v1/connectors/{jsonable_encoder(id)}",
        method="GET",
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            return HttpResponse(
                response=_response,
                data=typing.cast(
                    GetConnectorResponse,
                    construct_type(
                        type_=GetConnectorResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Each documented error status maps onto a typed exception; the raised
        # error carries the response headers and the parsed JSON payload.
        _status_to_error = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code with a JSON body: surface it as a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def delete(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> HttpResponse[DeleteConnectorResponse]:
    """
    Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.

    Parameters
    ----------
    id : str
        The ID of the connector to delete.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[DeleteConnectorResponse]
        OK
    """
    _response = self._client_wrapper.httpx_client.request(
        f"v1/connectors/{jsonable_encoder(id)}",
        method="DELETE",
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            return HttpResponse(
                response=_response,
                data=typing.cast(
                    DeleteConnectorResponse,
                    construct_type(
                        type_=DeleteConnectorResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Each documented error status maps onto a typed exception; the raised
        # error carries the response headers and the parsed JSON payload.
        _status_to_error = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code with a JSON body: surface it as a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def update(
    self,
    id: str,
    *,
    name: typing.Optional[str] = OMIT,
    url: typing.Optional[str] = OMIT,
    excludes: typing.Optional[typing.Sequence[str]] = OMIT,
    oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
    active: typing.Optional[bool] = OMIT,
    continue_on_failure: typing.Optional[bool] = OMIT,
    service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[UpdateConnectorResponse]:
    """
    Update a connector by ID. Omitted fields will not be updated. See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.

    Parameters
    ----------
    id : str
        The ID of the connector to update.

    name : typing.Optional[str]
        A human-readable name for the connector.

    url : typing.Optional[str]
        The URL of the connector that will be used to search for documents.

    excludes : typing.Optional[typing.Sequence[str]]
        A list of fields to exclude from the prompt (fields remain in the document).

    oauth : typing.Optional[CreateConnectorOAuth]
        The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.

    active : typing.Optional[bool]

    continue_on_failure : typing.Optional[bool]

    service_auth : typing.Optional[CreateConnectorServiceAuth]
        The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[UpdateConnectorResponse]
        OK
    """
    _response = self._client_wrapper.httpx_client.request(
        f"v1/connectors/{jsonable_encoder(id)}",
        method="PATCH",
        json={
            "name": name,
            "url": url,
            "excludes": excludes,
            "oauth": convert_and_respect_annotation_metadata(
                object_=oauth, annotation=CreateConnectorOAuth, direction="write"
            ),
            "active": active,
            "continue_on_failure": continue_on_failure,
            "service_auth": convert_and_respect_annotation_metadata(
                object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write"
            ),
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    try:
        if 200 <= _response.status_code < 300:
            return HttpResponse(
                response=_response,
                data=typing.cast(
                    UpdateConnectorResponse,
                    construct_type(
                        type_=UpdateConnectorResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Each documented error status maps onto a typed exception; the raised
        # error carries the response headers and the parsed JSON payload.
        _status_to_error = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code with a JSON body: surface it as a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def o_auth_authorize(
    self,
    id: str,
    *,
    after_token_redirect: typing.Optional[str] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[OAuthAuthorizeResponse]:
    """
    Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information.

    Parameters
    ----------
    id : str
        The ID of the connector to authorize.

    after_token_redirect : typing.Optional[str]
        The URL to redirect to after the connector has been authorized.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[OAuthAuthorizeResponse]
        OK
    """
    _response = self._client_wrapper.httpx_client.request(
        f"v1/connectors/{jsonable_encoder(id)}/oauth/authorize",
        method="POST",
        params={
            "after_token_redirect": after_token_redirect,
        },
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            return HttpResponse(
                response=_response,
                data=typing.cast(
                    OAuthAuthorizeResponse,
                    construct_type(
                        type_=OAuthAuthorizeResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Each documented error status maps onto a typed exception; the raised
        # error carries the response headers and the parsed JSON payload.
        _status_to_error = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code with a JSON body: surface it as a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
class AsyncRawConnectorsClient:
def __init__(self, *, client_wrapper: AsyncClientWrapper):
    """Store the async client wrapper used to issue HTTP requests."""
    self._client_wrapper = client_wrapper
async def list(
    self,
    *,
    limit: typing.Optional[float] = None,
    offset: typing.Optional[float] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[ListConnectorsResponse]:
    """
    Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.

    Parameters
    ----------
    limit : typing.Optional[float]
        Maximum number of connectors to return [0, 100].

    offset : typing.Optional[float]
        Number of connectors to skip before returning results [0, inf].

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[ListConnectorsResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/connectors",
        method="GET",
        params={
            "limit": limit,
            "offset": offset,
        },
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            return AsyncHttpResponse(
                response=_response,
                data=typing.cast(
                    ListConnectorsResponse,
                    construct_type(
                        type_=ListConnectorsResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Each documented error status maps onto a typed exception; the raised
        # error carries the response headers and the parsed JSON payload.
        _status_to_error = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code with a JSON body: surface it as a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def create(
    self,
    *,
    name: str,
    url: str,
    description: typing.Optional[str] = OMIT,
    excludes: typing.Optional[typing.Sequence[str]] = OMIT,
    oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
    active: typing.Optional[bool] = OMIT,
    continue_on_failure: typing.Optional[bool] = OMIT,
    service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[CreateConnectorResponse]:
    """
    Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) for more information.

    Parameters
    ----------
    name : str
        A human-readable name for the connector.

    url : str
        The URL of the connector that will be used to search for documents.

    description : typing.Optional[str]
        A description of the connector.

    excludes : typing.Optional[typing.Sequence[str]]
        A list of fields to exclude from the prompt (fields remain in the document).

    oauth : typing.Optional[CreateConnectorOAuth]
        The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.

    active : typing.Optional[bool]
        Whether the connector is active or not.

    continue_on_failure : typing.Optional[bool]
        Whether a chat request should continue or not if the request to this connector fails.

    service_auth : typing.Optional[CreateConnectorServiceAuth]
        The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[CreateConnectorResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/connectors",
        method="POST",
        json={
            "name": name,
            "description": description,
            "url": url,
            "excludes": excludes,
            "oauth": convert_and_respect_annotation_metadata(
                object_=oauth, annotation=CreateConnectorOAuth, direction="write"
            ),
            "active": active,
            "continue_on_failure": continue_on_failure,
            "service_auth": convert_and_respect_annotation_metadata(
                object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write"
            ),
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    try:
        if 200 <= _response.status_code < 300:
            return AsyncHttpResponse(
                response=_response,
                data=typing.cast(
                    CreateConnectorResponse,
                    construct_type(
                        type_=CreateConnectorResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Each documented error status maps onto a typed exception; the raised
        # error carries the response headers and the parsed JSON payload.
        _status_to_error = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code with a JSON body: surface it as a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def get(
self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[GetConnectorResponse]:
"""
Retrieve a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.
Parameters
----------
id : str
The ID of the connector to retrieve.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[GetConnectorResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
f"v1/connectors/{jsonable_encoder(id)}",
method="GET",
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
GetConnectorResponse,
construct_type(
type_=GetConnectorResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def delete(
self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[DeleteConnectorResponse]:
"""
Delete a connector by ID. See ['Connectors'](https://docs.cohere.com/docs/connectors) for more information.
Parameters
----------
id : str
The ID of the connector to delete.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[DeleteConnectorResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
f"v1/connectors/{jsonable_encoder(id)}",
method="DELETE",
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
DeleteConnectorResponse,
construct_type(
type_=DeleteConnectorResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def update(
self,
id: str,
*,
name: typing.Optional[str] = OMIT,
url: typing.Optional[str] = OMIT,
excludes: typing.Optional[typing.Sequence[str]] = OMIT,
oauth: typing.Optional[CreateConnectorOAuth] = OMIT,
active: typing.Optional[bool] = OMIT,
continue_on_failure: typing.Optional[bool] = OMIT,
service_auth: typing.Optional[CreateConnectorServiceAuth] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[UpdateConnectorResponse]:
"""
Update a connector by ID. Omitted fields will not be updated. See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information.
Parameters
----------
id : str
The ID of the connector to update.
name : typing.Optional[str]
A human-readable name for the connector.
url : typing.Optional[str]
The URL of the connector that will be used to search for documents.
excludes : typing.Optional[typing.Sequence[str]]
A list of fields to exclude from the prompt (fields remain in the document).
oauth : typing.Optional[CreateConnectorOAuth]
The OAuth 2.0 configuration for the connector. Cannot be specified if service_auth is specified.
active : typing.Optional[bool]
continue_on_failure : typing.Optional[bool]
service_auth : typing.Optional[CreateConnectorServiceAuth]
The service to service authentication configuration for the connector. Cannot be specified if oauth is specified.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[UpdateConnectorResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
f"v1/connectors/{jsonable_encoder(id)}",
method="PATCH",
json={
"name": name,
"url": url,
"excludes": excludes,
"oauth": convert_and_respect_annotation_metadata(
object_=oauth, annotation=CreateConnectorOAuth, direction="write"
),
"active": active,
"continue_on_failure": continue_on_failure,
"service_auth": convert_and_respect_annotation_metadata(
object_=service_auth, annotation=CreateConnectorServiceAuth, direction="write"
),
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
UpdateConnectorResponse,
construct_type(
type_=UpdateConnectorResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def o_auth_authorize(
self,
id: str,
*,
after_token_redirect: typing.Optional[str] = None,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[OAuthAuthorizeResponse]:
"""
Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information.
Parameters
----------
id : str
The ID of the connector to authorize.
after_token_redirect : typing.Optional[str]
The URL to redirect to after the connector has been authorized.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[OAuthAuthorizeResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
f"v1/connectors/{jsonable_encoder(id)}/oauth/authorize",
method="POST",
params={
"after_token_redirect": after_token_redirect,
},
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
OAuthAuthorizeResponse,
construct_type(
type_=OAuthAuthorizeResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/core/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .api_error import ApiError
from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
from .datetime_utils import Rfc2822DateTime, parse_rfc2822_datetime, serialize_datetime
from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
from .http_client import AsyncHttpClient, HttpClient
from .http_response import AsyncHttpResponse, HttpResponse
from .jsonable_encoder import jsonable_encoder
from .logging import ConsoleLogger, ILogger, LogConfig, LogLevel, Logger, create_logger
from .parse_error import ParsingError
from .pydantic_utilities import (
IS_PYDANTIC_V2,
UniversalBaseModel,
UniversalRootModel,
parse_obj_as,
universal_field_validator,
universal_root_validator,
update_forward_refs,
)
from .query_encoder import encode_query
from .remove_none_from_dict import remove_none_from_dict
from .request_options import RequestOptions
from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
from .unchecked_base_model import UncheckedBaseModel, UnionMetadata, construct_type
# Maps each public attribute name to the relative module that defines it.
# __getattr__ below uses this table to import modules lazily on first access,
# keeping `import cohere.core` cheap while preserving the flat public API.
_dynamic_imports: typing.Dict[str, str] = {
    "ApiError": ".api_error",
    "AsyncClientWrapper": ".client_wrapper",
    "AsyncHttpClient": ".http_client",
    "AsyncHttpResponse": ".http_response",
    "BaseClientWrapper": ".client_wrapper",
    "ConsoleLogger": ".logging",
    "FieldMetadata": ".serialization",
    "File": ".file",
    "HttpClient": ".http_client",
    "HttpResponse": ".http_response",
    "ILogger": ".logging",
    "IS_PYDANTIC_V2": ".pydantic_utilities",
    "LogConfig": ".logging",
    "LogLevel": ".logging",
    "Logger": ".logging",
    "ParsingError": ".parse_error",
    "RequestOptions": ".request_options",
    "Rfc2822DateTime": ".datetime_utils",
    "SyncClientWrapper": ".client_wrapper",
    "UncheckedBaseModel": ".unchecked_base_model",
    "UnionMetadata": ".unchecked_base_model",
    "UniversalBaseModel": ".pydantic_utilities",
    "UniversalRootModel": ".pydantic_utilities",
    "construct_type": ".unchecked_base_model",
    "convert_and_respect_annotation_metadata": ".serialization",
    "convert_file_dict_to_httpx_tuples": ".file",
    "create_logger": ".logging",
    "encode_query": ".query_encoder",
    "jsonable_encoder": ".jsonable_encoder",
    "parse_obj_as": ".pydantic_utilities",
    "parse_rfc2822_datetime": ".datetime_utils",
    "remove_none_from_dict": ".remove_none_from_dict",
    "serialize_datetime": ".datetime_utils",
    "universal_field_validator": ".pydantic_utilities",
    "universal_root_validator": ".pydantic_utilities",
    "update_forward_refs": ".pydantic_utilities",
    "with_content_type": ".file",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve *attr_name* lazily by importing the module registered for it."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(module_name, __package__)
        # A mapping of "X" -> ".X" means the attribute *is* the submodule itself;
        # otherwise pull the named attribute out of the imported module.
        return resolved if module_name == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily importable names to dir() and tab completion."""
    return sorted(_dynamic_imports)
# Static export list mirroring the keys of _dynamic_imports; keeps star-imports
# and static analyzers in sync with the lazily resolved attributes.
__all__ = [
    "ApiError",
    "AsyncClientWrapper",
    "AsyncHttpClient",
    "AsyncHttpResponse",
    "BaseClientWrapper",
    "ConsoleLogger",
    "FieldMetadata",
    "File",
    "HttpClient",
    "HttpResponse",
    "ILogger",
    "IS_PYDANTIC_V2",
    "LogConfig",
    "LogLevel",
    "Logger",
    "ParsingError",
    "RequestOptions",
    "Rfc2822DateTime",
    "SyncClientWrapper",
    "UncheckedBaseModel",
    "UnionMetadata",
    "UniversalBaseModel",
    "UniversalRootModel",
    "construct_type",
    "convert_and_respect_annotation_metadata",
    "convert_file_dict_to_httpx_tuples",
    "create_logger",
    "encode_query",
    "jsonable_encoder",
    "parse_obj_as",
    "parse_rfc2822_datetime",
    "remove_none_from_dict",
    "serialize_datetime",
    "universal_field_validator",
    "universal_root_validator",
    "update_forward_refs",
    "with_content_type",
]
================================================
FILE: src/cohere/core/api_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import Any, Dict, Optional
class ApiError(Exception):
    """Base exception for HTTP error responses.

    Carries the response headers, the status code, and the (already decoded)
    response body, all optional so the error can also describe transport-level
    failures where no response exists.
    """

    headers: Optional[Dict[str, str]]
    status_code: Optional[int]
    body: Any

    def __init__(
        self,
        *,
        headers: Optional[Dict[str, str]] = None,
        status_code: Optional[int] = None,
        body: Any = None,
    ) -> None:
        self.headers = headers
        self.status_code = status_code
        self.body = body

    def __str__(self) -> str:
        # Keep all three fields visible in logs and tracebacks.
        return "headers: {0}, status_code: {1}, body: {2}".format(self.headers, self.status_code, self.body)
================================================
FILE: src/cohere/core/client_wrapper.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import httpx
from .http_client import AsyncHttpClient, HttpClient
from .logging import LogConfig, Logger
class BaseClientWrapper:
    """Holds the configuration shared by the sync and async client wrappers:
    auth token (static string or zero-arg callable), base URL, extra headers,
    timeout, and logging configuration."""

    def __init__(
        self,
        *,
        client_name: typing.Optional[str] = None,
        token: typing.Union[str, typing.Callable[[], str]],
        headers: typing.Optional[typing.Dict[str, str]] = None,
        base_url: str,
        timeout: typing.Optional[float] = None,
        logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
    ):
        self._client_name = client_name
        self._token = token
        self._headers = headers
        self._base_url = base_url
        self._timeout = timeout
        self._logging = logging

    def get_headers(self) -> typing.Dict[str, str]:
        """Build the default header set for every request, including SDK
        telemetry headers, any custom headers, and the Authorization header."""
        import platform

        headers: typing.Dict[str, str] = {
            "User-Agent": "cohere/6.1.0",
            "X-Fern-Language": "Python",
            "X-Fern-Runtime": f"python/{platform.python_version()}",
            "X-Fern-Platform": f"{platform.system().lower()}/{platform.release()}",
            "X-Fern-SDK-Name": "cohere",
            "X-Fern-SDK-Version": "6.1.0",
        }
        # Custom headers may override any of the defaults above.
        headers.update(self.get_custom_headers() or {})
        if self._client_name is not None:
            headers["X-Client-Name"] = self._client_name
        # Token is resolved on every call so callable tokens stay fresh.
        headers["Authorization"] = f"Bearer {self._get_token()}"
        return headers

    def _get_token(self) -> str:
        token = self._token
        if isinstance(token, str):
            return token
        return token()

    def get_custom_headers(self) -> typing.Optional[typing.Dict[str, str]]:
        return self._headers

    def get_base_url(self) -> str:
        return self._base_url

    def get_timeout(self) -> typing.Optional[float]:
        return self._timeout
class SyncClientWrapper(BaseClientWrapper):
    """Client wrapper backed by a synchronous ``httpx.Client``."""

    def __init__(
        self,
        *,
        client_name: typing.Optional[str] = None,
        token: typing.Union[str, typing.Callable[[], str]],
        headers: typing.Optional[typing.Dict[str, str]] = None,
        base_url: str,
        timeout: typing.Optional[float] = None,
        logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
        httpx_client: httpx.Client,
    ):
        super().__init__(
            client_name=client_name, token=token, headers=headers, base_url=base_url, timeout=timeout, logging=logging
        )
        # Wrap the raw httpx client so every request resolves headers, timeout,
        # and base URL at call time (callables), not at construction time.
        self.httpx_client = HttpClient(
            httpx_client=httpx_client,
            base_headers=self.get_headers,
            base_timeout=self.get_timeout,
            base_url=self.get_base_url,
            logging_config=self._logging,
        )
class AsyncClientWrapper(BaseClientWrapper):
    """Client wrapper backed by an ``httpx.AsyncClient``.

    Additionally supports an async token provider (``async_token``), which, when
    supplied, overrides the Authorization header per request.
    """

    def __init__(
        self,
        *,
        client_name: typing.Optional[str] = None,
        token: typing.Union[str, typing.Callable[[], str]],
        headers: typing.Optional[typing.Dict[str, str]] = None,
        base_url: str,
        timeout: typing.Optional[float] = None,
        logging: typing.Optional[typing.Union[LogConfig, Logger]] = None,
        async_token: typing.Optional[typing.Callable[[], typing.Awaitable[str]]] = None,
        httpx_client: httpx.AsyncClient,
    ):
        super().__init__(
            client_name=client_name, token=token, headers=headers, base_url=base_url, timeout=timeout, logging=logging
        )
        self._async_token = async_token
        # Wrap the raw httpx client so every request resolves headers, timeout,
        # and base URL at call time (callables), not at construction time.
        self.httpx_client = AsyncHttpClient(
            httpx_client=httpx_client,
            base_headers=self.get_headers,
            base_timeout=self.get_timeout,
            base_url=self.get_base_url,
            async_base_headers=self.async_get_headers,
            logging_config=self._logging,
        )

    async def async_get_headers(self) -> typing.Dict[str, str]:
        # Start from the synchronous header set, then let the async token
        # provider (if any) replace the Authorization header.
        headers = self.get_headers()
        if self._async_token is not None:
            token = await self._async_token()
            headers["Authorization"] = f"Bearer {token}"
        return headers
================================================
FILE: src/cohere/core/datetime_utils.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
from email.utils import parsedate_to_datetime
from typing import Any
import pydantic
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
def parse_rfc2822_datetime(v: Any) -> dt.datetime:
    """
    Parse an RFC 2822 datetime string (e.g., "Wed, 02 Oct 2002 13:00:00 GMT")
    into a datetime object. If the value is already a datetime, return it as-is.
    Falls back to ISO 8601 parsing if RFC 2822 parsing fails.
    """
    if isinstance(v, dt.datetime):
        return v
    if not isinstance(v, str):
        raise ValueError(f"Expected str or datetime, got {type(v)}")
    try:
        return parsedate_to_datetime(v)
    except Exception:
        # Not RFC 2822 — try ISO 8601, normalizing the "Z" suffix that
        # fromisoformat only accepts natively on Python >= 3.11.
        return dt.datetime.fromisoformat(v.replace("Z", "+00:00"))
class Rfc2822DateTime(dt.datetime):
    """A datetime subclass that parses RFC 2822 date strings.

    On Pydantic V1, uses __get_validators__ for pre-validation.
    On Pydantic V2, uses __get_pydantic_core_schema__ for BeforeValidator-style parsing.
    """

    @classmethod
    def __get_validators__(cls):  # type: ignore[no-untyped-def]
        # Pydantic V1 hook: run parse_rfc2822_datetime before standard validation.
        yield parse_rfc2822_datetime

    @classmethod
    def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Any) -> Any:  # type: ignore[override]
        # Pydantic V2 hook; pydantic_core is imported lazily so that V1-only
        # environments never need it at module import time.
        from pydantic_core import core_schema

        return core_schema.no_info_before_validator_function(parse_rfc2822_datetime, core_schema.datetime_schema())
def serialize_datetime(v: dt.datetime) -> str:
    """
    Serialize a datetime including timezone info.

    Aware datetimes keep their offset; naive datetimes are assumed to be in the
    current runtime's local timezone. UTC is rendered with a trailing "Z"
    instead of "+00:00"; every other zone uses the standard +/-HH:MM offset.
    """
    if v.tzinfo is None:
        # Attach the local timezone so the output always carries an offset.
        v = v.replace(tzinfo=dt.datetime.now().astimezone().tzinfo)
    rendered = v.isoformat()
    if v.tzinfo is not None and v.tzinfo.tzname(None) == dt.timezone.utc.tzname(None):
        # UTC is a special case: use "Z" rather than "+00:00".
        return rendered.replace("+00:00", "Z")
    return rendered
================================================
FILE: src/cohere/core/file.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import IO, Dict, List, Mapping, Optional, Tuple, Union, cast
# File typing inspired by the flexibility of types within the httpx library
# https://github.com/encode/httpx/blob/master/httpx/_types.py
FileContent = Union[IO[bytes], bytes, str]
File = Union[
    # file (or bytes)
    FileContent,
    # (filename, file (or bytes))
    Tuple[Optional[str], FileContent],
    # (filename, file (or bytes), content_type)
    Tuple[Optional[str], FileContent, Optional[str]],
    # (filename, file (or bytes), content_type, headers)
    Tuple[
        Optional[str],
        FileContent,
        Optional[str],
        Mapping[str, str],
    ],
]


def convert_file_dict_to_httpx_tuples(
    d: Dict[str, Union[File, List[File]]],
) -> List[Tuple[str, File]]:
    """
    Flatten a name -> file(s) mapping into httpx-style (name, file) tuples.

    httpx normally takes a dict of files, but repeating a field name (required
    to upload a list of files) needs the list-of-tuples form, which also works
    for single files: https://github.com/encode/httpx/pull/1032
    """
    pairs: List[Tuple[str, File]] = []
    for field_name, value in d.items():
        # Normalize to a list so single files and lists share one code path.
        file_items = value if isinstance(value, list) else [value]
        pairs.extend((field_name, item) for item in file_items)
    return pairs


def with_content_type(*, file: File, default_content_type: str) -> File:
    """
    Normalize *file* into a tuple form that carries a content type.

    An explicitly provided content type wins; otherwise default_content_type is
    used. Bare file content is wrapped as (None, content, content_type).
    """
    if not isinstance(file, tuple):
        return (None, file, default_content_type)
    if len(file) == 2:
        filename, content = cast(Tuple[Optional[str], FileContent], file)  # type: ignore
        return (filename, content, default_content_type)
    if len(file) == 3:
        filename, content, explicit_type = cast(Tuple[Optional[str], FileContent, Optional[str]], file)  # type: ignore
        return (filename, content, explicit_type or default_content_type)
    if len(file) == 4:
        filename, content, explicit_type, headers = cast(  # type: ignore
            Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], file
        )
        return (filename, content, explicit_type or default_content_type, headers)
    raise ValueError(f"Unexpected tuple length: {len(file)}")
================================================
FILE: src/cohere/core/force_multipart.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import Any, Dict
class ForceMultipartDict(Dict[str, Any]):
    """
    A dict that is truthy even when empty.

    An empty plain dict passed as request files would be falsy, and
    multipart/form-data encoding would be skipped; this subclass forces
    multipart encoding regardless of contents.
    """

    def __bool__(self) -> bool:
        # Truthy even with no entries — that is the whole point of the class.
        return True


FORCE_MULTIPART = ForceMultipartDict()
================================================
FILE: src/cohere/core/http_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import asyncio
import email.utils
import re
import time
import typing
from contextlib import asynccontextmanager, contextmanager
from random import random
import httpx
from .file import File, convert_file_dict_to_httpx_tuples
from .force_multipart import FORCE_MULTIPART
from .jsonable_encoder import jsonable_encoder
from .logging import LogConfig, Logger, create_logger
from .query_encoder import encode_query
from .remove_none_from_dict import remove_none_from_dict as remove_none_from_dict
from .request_options import RequestOptions
from httpx._types import RequestFiles
INITIAL_RETRY_DELAY_SECONDS = 1.0
MAX_RETRY_DELAY_SECONDS = 60.0
JITTER_FACTOR = 0.2 # 20% random jitter
def _parse_retry_after(response_headers: httpx.Headers) -> typing.Optional[float]:
"""
This function parses the `Retry-After` header in a HTTP response and returns the number of seconds to wait.
Inspired by the urllib3 retry implementation.
"""
retry_after_ms = response_headers.get("retry-after-ms")
if retry_after_ms is not None:
try:
return int(retry_after_ms) / 1000 if retry_after_ms > 0 else 0
except Exception:
pass
retry_after = response_headers.get("retry-after")
if retry_after is None:
return None
# Attempt to parse the header as an int.
if re.match(r"^\s*[0-9]+\s*$", retry_after):
seconds = float(retry_after)
# Fallback to parsing it as a date.
else:
retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None:
return None
if retry_date_tuple[9] is None: # Python 2
# Assume UTC if no timezone was specified
# On Python2.7, parsedate_tz returns None for a timezone offset
# instead of 0 if no timezone is given, where mktime_tz treats
# a None timezone offset as local time.
retry_date_tuple = retry_date_tuple[:9] + (0,) + retry_date_tuple[10:]
retry_date = email.utils.mktime_tz(retry_date_tuple)
seconds = retry_date - time.time()
if seconds < 0:
seconds = 0
return seconds
def _add_positive_jitter(delay: float) -> float:
    """Scale *delay* by a random factor in [1.0, 1.0 + JITTER_FACTOR).

    Positive-only jitter never retries earlier than the base delay, which
    matters when the delay comes from a server-provided reset time, while
    still spreading out simultaneous clients (thundering herd).
    """
    return delay * (1 + random() * JITTER_FACTOR)
def _add_symmetric_jitter(delay: float) -> float:
    """Scale *delay* by a random factor centered on 1.0 (±JITTER_FACTOR/2).

    Used for exponential backoff, where landing slightly below the nominal
    delay is acceptable.
    """
    centered_offset = random() - 0.5
    return delay * (1 + centered_offset * JITTER_FACTOR)
def _parse_x_ratelimit_reset(response_headers: httpx.Headers) -> typing.Optional[float]:
    """Parse the X-RateLimit-Reset header (Unix timestamp in seconds).

    Returns the number of seconds until the reset, or ``None`` when the header
    is missing, unparsable, or already in the past.
    """
    raw_reset = response_headers.get("x-ratelimit-reset")
    if raw_reset is None:
        return None
    try:
        wait = int(raw_reset) - time.time()
    except (ValueError, TypeError):
        return None
    return wait if wait > 0 else None
def _retry_timeout(response: httpx.Response, retries: int) -> float:
    """Compute how long to wait before retrying a failed request.

    Preference order: an explicit Retry-After header, then the
    X-RateLimit-Reset header (with positive jitter), and finally capped
    exponential backoff (with symmetric jitter).
    """
    # 1. Honor an explicit Retry-After header when present and positive.
    server_delay = _parse_retry_after(response.headers)
    if server_delay is not None and server_delay > 0:
        return min(server_delay, MAX_RETRY_DELAY_SECONDS)
    # 2. Next preference: the rate-limit reset timestamp, jittered upward only.
    reset_delay = _parse_x_ratelimit_reset(response.headers)
    if reset_delay is not None:
        return _add_positive_jitter(min(reset_delay, MAX_RETRY_DELAY_SECONDS))
    # 3. Otherwise: capped exponential backoff with symmetric jitter.
    backoff_delay = min(INITIAL_RETRY_DELAY_SECONDS * (2.0 ** retries), MAX_RETRY_DELAY_SECONDS)
    return _add_symmetric_jitter(backoff_delay)
def _retry_timeout_from_retries(retries: int) -> float:
    """Capped exponential backoff with jitter, for when no response exists
    (e.g. a connection error) and therefore no retry headers are available."""
    backoff_delay = min(INITIAL_RETRY_DELAY_SECONDS * (2.0 ** retries), MAX_RETRY_DELAY_SECONDS)
    return _add_symmetric_jitter(backoff_delay)
def _should_retry(response: httpx.Response) -> bool:
    """Return True for responses worth retrying: any 5xx, plus 408 (request
    timeout), 409 (conflict), and 429 (rate limited)."""
    status = response.status_code
    return status >= 500 or status in (429, 408, 409)
# Lower-cased header names whose values must never appear in logs; compared
# case-insensitively by _redact_headers.
_SENSITIVE_HEADERS = frozenset(
    {
        "authorization",
        "www-authenticate",
        "x-api-key",
        "api-key",
        "apikey",
        "x-api-token",
        "x-auth-token",
        "auth-token",
        "cookie",
        "set-cookie",
        "proxy-authorization",
        "proxy-authenticate",
        "x-csrf-token",
        "x-xsrf-token",
        "x-session-token",
        "x-access-token",
    }
)
def _redact_headers(headers: typing.Dict[str, str]) -> typing.Dict[str, str]:
    """Return a copy of *headers* with sensitive values replaced by a marker.

    Matching against _SENSITIVE_HEADERS is case-insensitive; non-sensitive
    entries pass through untouched.
    """
    redacted: typing.Dict[str, str] = {}
    for name, value in headers.items():
        redacted[name] = "[REDACTED]" if name.lower() in _SENSITIVE_HEADERS else value
    return redacted
def _build_url(base_url: str, path: typing.Optional[str]) -> str:
"""
Build a full URL by joining a base URL with a path.
This function correctly handles base URLs that contain path prefixes (e.g., tenant-based URLs)
by using string concatenation instead of urllib.parse.urljoin(), which would incorrectly
strip path components when the path starts with '/'.
Example:
>>> _build_url("https://cloud.example.com/org/tenant/api", "/users")
'https://cloud.example.com/org/tenant/api/users'
Args:
base_url: The base URL, which may contain path prefixes.
path: The path to append. Can be None or empty string.
Returns:
The full URL with base_url and path properly joined.
"""
if not path:
return base_url
return f"{base_url.rstrip('/')}/{path.lstrip('/')}"
def _maybe_filter_none_from_multipart_data(
    data: typing.Optional[typing.Any],
    request_files: typing.Optional[RequestFiles],
    force_multipart: typing.Optional[bool],
) -> typing.Optional[typing.Any]:
    """Strip None values from a mapping body destined for multipart encoding.

    httpx would otherwise serialize None fields as empty strings in
    multipart/form requests. Applies only when files are present or multipart
    is forced; everything else passes through untouched.
    """
    going_multipart = bool(request_files) or bool(force_multipart)
    if not going_multipart:
        return data
    if not isinstance(data, typing.Mapping):
        return data
    return remove_none_from_dict(data)
def remove_omit_from_dict(
    original: typing.Dict[str, typing.Optional[typing.Any]],
    omit: typing.Optional[typing.Any],
) -> typing.Dict[str, typing.Any]:
    """Drop entries whose value *is* the ``omit`` sentinel (identity check).

    When no sentinel is supplied, the original mapping is returned unchanged
    (not copied).
    """
    if omit is None:
        return original
    return {key: value for key, value in original.items() if value is not omit}
def maybe_filter_request_body(
    data: typing.Optional[typing.Any],
    request_options: typing.Optional[RequestOptions],
    omit: typing.Optional[typing.Any],
) -> typing.Optional[typing.Any]:
    """Encode a request body, merging in any ``additional_body_parameters``.

    Returns None when there is no body and no request options; encodes
    non-mapping bodies as-is; for mapping bodies, removes omit-sentinel entries
    and layers the extra parameters from *request_options* on top.
    """
    extra_params: typing.Dict[str, typing.Any] = {}
    if request_options is not None:
        extra_params = jsonable_encoder(request_options.get("additional_body_parameters", {})) or {}
    if data is None:
        # No explicit body: send just the extras (possibly {}), or nothing at all.
        return extra_params if request_options is not None else None
    if not isinstance(data, typing.Mapping):
        return jsonable_encoder(data)
    merged = dict(jsonable_encoder(remove_omit_from_dict(data, omit)))  # type: ignore
    merged.update(extra_params)
    return merged
# Abstracted out for testing purposes
def get_request_body(
    *,
    json: typing.Optional[typing.Any],
    data: typing.Optional[typing.Any],
    request_options: typing.Optional[RequestOptions],
    omit: typing.Optional[typing.Any],
) -> typing.Tuple[typing.Optional[typing.Any], typing.Optional[typing.Any]]:
    """Build the ``(json_body, data_body)`` pair for an outgoing request.

    At most one slot is populated: ``data`` wins when provided, otherwise
    ``json`` is used — even when both are None, so that any
    ``additional_body_parameters`` from *request_options* still get injected.
    """
    json_body: typing.Optional[typing.Any] = None
    data_body: typing.Optional[typing.Any] = None
    if data is not None:
        data_body = maybe_filter_request_body(data, request_options, omit)
    else:
        json_body = maybe_filter_request_body(json, request_options, omit)
    has_additional_body_parameters = bool(
        request_options is not None and request_options.get("additional_body_parameters")
    )
    # Collapse an empty dict to None only when the caller did not explicitly
    # provide a body and no additional parameters exist — explicit empty bodies
    # (request types whose fields are all optional) must survive as {}.
    if not has_additional_body_parameters:
        if json is None and json_body == {}:
            json_body = None
        if data is None and data_body == {}:
            data_body = None
    return json_body, data_body
class HttpClient:
    """Synchronous HTTP client used by the generated SDK.

    Wraps an ``httpx.Client`` and layers on:

    - base-URL resolution (per-request override or client-level callable),
    - header / query / body merging from ``RequestOptions``,
    - multipart handling (including forced multipart encoding),
    - retries with Retry-After / X-RateLimit-Reset aware backoff,
    - debug/error logging with sensitive header values redacted.
    """

    def __init__(
        self,
        *,
        httpx_client: httpx.Client,
        base_timeout: typing.Callable[[], typing.Optional[float]],
        base_headers: typing.Callable[[], typing.Dict[str, str]],
        base_url: typing.Optional[typing.Callable[[], str]] = None,
        base_max_retries: int = 2,
        logging_config: typing.Optional[typing.Union[LogConfig, Logger]] = None,
    ):
        self.base_url = base_url
        self.base_timeout = base_timeout
        self.base_headers = base_headers
        self.base_max_retries = base_max_retries
        self.httpx_client = httpx_client
        self.logger = create_logger(logging_config)

    def get_base_url(self, maybe_base_url: typing.Optional[str]) -> str:
        """Resolve the base URL, preferring the per-request override.

        Raises:
            ValueError: if neither an override nor a client-level base URL is set.
        """
        base_url = maybe_base_url
        if self.base_url is not None and base_url is None:
            base_url = self.base_url()
        if base_url is None:
            raise ValueError("A base_url is required to make this request, please provide one and try again.")
        return base_url

    def request(
        self,
        path: typing.Optional[str] = None,
        *,
        method: str,
        base_url: typing.Optional[str] = None,
        params: typing.Optional[typing.Dict[str, typing.Any]] = None,
        json: typing.Optional[typing.Any] = None,
        data: typing.Optional[typing.Any] = None,
        content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
        files: typing.Optional[
            typing.Union[
                typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
                typing.List[typing.Tuple[str, File]],
            ]
        ] = None,
        headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
        request_options: typing.Optional[RequestOptions] = None,
        retries: int = 0,
        omit: typing.Optional[typing.Any] = None,
        force_multipart: typing.Optional[bool] = None,
    ) -> httpx.Response:
        """Send one request, retrying retryable failures up to ``max_retries``.

        Connection-level errors (``httpx.ConnectError`` / ``RemoteProtocolError``)
        and retryable status codes (5xx, 408, 409, 429) are retried recursively
        with header-aware backoff; other outcomes are returned/raised as-is.
        """
        base_url = self.get_base_url(base_url)
        timeout = (
            request_options.get("timeout_in_seconds")
            if request_options is not None and request_options.get("timeout_in_seconds") is not None
            else self.base_timeout()
        )
        json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
        request_files: typing.Optional[RequestFiles] = (
            convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
            if (files is not None and files is not omit and isinstance(files, dict))
            else None
        )
        # Force multipart encoding even without files when requested.
        if (request_files is None or len(request_files) == 0) and force_multipart:
            request_files = FORCE_MULTIPART
        data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart)
        # Compute encoded params separately to avoid passing empty list to httpx
        # (httpx strips existing query params from URL when params=[] is passed)
        _encoded_params = encode_query(
            jsonable_encoder(
                remove_none_from_dict(
                    remove_omit_from_dict(
                        {
                            **(params if params is not None else {}),
                            **(
                                request_options.get("additional_query_parameters", {}) or {}
                                if request_options is not None
                                else {}
                            ),
                        },
                        omit,
                    )
                )
            )
        )
        _request_url = _build_url(base_url, path)
        _request_headers = jsonable_encoder(
            remove_none_from_dict(
                {
                    **self.base_headers(),
                    **(headers if headers is not None else {}),
                    **(request_options.get("additional_headers", {}) or {} if request_options is not None else {}),
                }
            )
        )
        if self.logger.is_debug():
            self.logger.debug(
                "Making HTTP request",
                method=method,
                url=_request_url,
                headers=_redact_headers(_request_headers),
                has_body=json_body is not None or data_body is not None,
            )
        max_retries: int = (
            request_options.get("max_retries", self.base_max_retries)
            if request_options is not None
            else self.base_max_retries
        )
        try:
            response = self.httpx_client.request(
                method=method,
                url=_request_url,
                headers=_request_headers,
                params=_encoded_params if _encoded_params else None,
                json=json_body,
                data=data_body,
                content=content,
                files=request_files,
                timeout=timeout,
            )
        except (httpx.ConnectError, httpx.RemoteProtocolError):
            # No response available: back off exponentially, then retry.
            if retries < max_retries:
                time.sleep(_retry_timeout_from_retries(retries=retries))
                return self.request(
                    path=path,
                    method=method,
                    base_url=base_url,
                    params=params,
                    json=json,
                    data=data,
                    content=content,
                    files=files,
                    headers=headers,
                    request_options=request_options,
                    retries=retries + 1,
                    omit=omit,
                    force_multipart=force_multipart,
                )
            raise
        if _should_retry(response=response):
            if retries < max_retries:
                # Response-aware backoff (honors Retry-After / X-RateLimit-Reset).
                time.sleep(_retry_timeout(response=response, retries=retries))
                return self.request(
                    path=path,
                    method=method,
                    base_url=base_url,
                    params=params,
                    json=json,
                    data=data,
                    content=content,
                    files=files,
                    headers=headers,
                    request_options=request_options,
                    retries=retries + 1,
                    omit=omit,
                    force_multipart=force_multipart,
                )
        if self.logger.is_debug():
            if 200 <= response.status_code < 400:
                self.logger.debug(
                    "HTTP request succeeded",
                    method=method,
                    url=_request_url,
                    status_code=response.status_code,
                )
        if self.logger.is_error():
            if response.status_code >= 400:
                self.logger.error(
                    "HTTP request failed with error status",
                    method=method,
                    url=_request_url,
                    status_code=response.status_code,
                )
        return response

    @contextmanager
    def stream(
        self,
        path: typing.Optional[str] = None,
        *,
        method: str,
        base_url: typing.Optional[str] = None,
        params: typing.Optional[typing.Dict[str, typing.Any]] = None,
        json: typing.Optional[typing.Any] = None,
        data: typing.Optional[typing.Any] = None,
        content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
        files: typing.Optional[
            typing.Union[
                typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
                typing.List[typing.Tuple[str, File]],
            ]
        ] = None,
        headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
        request_options: typing.Optional[RequestOptions] = None,
        retries: int = 0,
        omit: typing.Optional[typing.Any] = None,
        force_multipart: typing.Optional[bool] = None,
    ) -> typing.Iterator[httpx.Response]:
        """Context manager yielding a streaming response (no retries are applied)."""
        base_url = self.get_base_url(base_url)
        timeout = (
            request_options.get("timeout_in_seconds")
            if request_options is not None and request_options.get("timeout_in_seconds") is not None
            else self.base_timeout()
        )
        request_files: typing.Optional[RequestFiles] = (
            convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
            if (files is not None and files is not omit and isinstance(files, dict))
            else None
        )
        if (request_files is None or len(request_files) == 0) and force_multipart:
            request_files = FORCE_MULTIPART
        json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
        data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart)
        # Compute encoded params separately to avoid passing empty list to httpx
        # (httpx strips existing query params from URL when params=[] is passed)
        _encoded_params = encode_query(
            jsonable_encoder(
                remove_none_from_dict(
                    remove_omit_from_dict(
                        {
                            **(params if params is not None else {}),
                            **(
                                # BUG FIX: `or {}` guards against an explicit None
                                # value in request_options (mirrors request()).
                                request_options.get("additional_query_parameters", {}) or {}
                                if request_options is not None
                                else {}
                            ),
                        },
                        omit,
                    )
                )
            )
        )
        _request_url = _build_url(base_url, path)
        _request_headers = jsonable_encoder(
            remove_none_from_dict(
                {
                    **self.base_headers(),
                    **(headers if headers is not None else {}),
                    # BUG FIX: `or {}` guards against an explicit None value
                    # (mirrors request()).
                    **(request_options.get("additional_headers", {}) or {} if request_options is not None else {}),
                }
            )
        )
        if self.logger.is_debug():
            self.logger.debug(
                "Making streaming HTTP request",
                method=method,
                url=_request_url,
                headers=_redact_headers(_request_headers),
            )
        with self.httpx_client.stream(
            method=method,
            url=_request_url,
            headers=_request_headers,
            params=_encoded_params if _encoded_params else None,
            json=json_body,
            data=data_body,
            content=content,
            files=request_files,
            timeout=timeout,
        ) as stream:
            yield stream
class AsyncHttpClient:
    """Asynchronous HTTP client used by the generated SDK.

    Mirrors ``HttpClient`` on top of ``httpx.AsyncClient``, and additionally
    supports an async header provider (``async_base_headers``) for token
    sources that must be awaited.
    """

    def __init__(
        self,
        *,
        httpx_client: httpx.AsyncClient,
        base_timeout: typing.Callable[[], typing.Optional[float]],
        base_headers: typing.Callable[[], typing.Dict[str, str]],
        base_url: typing.Optional[typing.Callable[[], str]] = None,
        base_max_retries: int = 2,
        async_base_headers: typing.Optional[typing.Callable[[], typing.Awaitable[typing.Dict[str, str]]]] = None,
        logging_config: typing.Optional[typing.Union[LogConfig, Logger]] = None,
    ):
        self.base_url = base_url
        self.base_timeout = base_timeout
        self.base_headers = base_headers
        self.base_max_retries = base_max_retries
        self.async_base_headers = async_base_headers
        self.httpx_client = httpx_client
        self.logger = create_logger(logging_config)

    async def _get_headers(self) -> typing.Dict[str, str]:
        """Return base headers, preferring the async provider when configured."""
        if self.async_base_headers is not None:
            return await self.async_base_headers()
        return self.base_headers()

    def get_base_url(self, maybe_base_url: typing.Optional[str]) -> str:
        """Resolve the base URL, preferring the per-request override.

        Raises:
            ValueError: if neither an override nor a client-level base URL is set.
        """
        base_url = maybe_base_url
        if self.base_url is not None and base_url is None:
            base_url = self.base_url()
        if base_url is None:
            raise ValueError("A base_url is required to make this request, please provide one and try again.")
        return base_url

    async def request(
        self,
        path: typing.Optional[str] = None,
        *,
        method: str,
        base_url: typing.Optional[str] = None,
        params: typing.Optional[typing.Dict[str, typing.Any]] = None,
        json: typing.Optional[typing.Any] = None,
        data: typing.Optional[typing.Any] = None,
        content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
        files: typing.Optional[
            typing.Union[
                typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
                typing.List[typing.Tuple[str, File]],
            ]
        ] = None,
        headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
        request_options: typing.Optional[RequestOptions] = None,
        retries: int = 0,
        omit: typing.Optional[typing.Any] = None,
        force_multipart: typing.Optional[bool] = None,
    ) -> httpx.Response:
        """Send one request, retrying retryable failures up to ``max_retries``.

        Connection-level errors (``httpx.ConnectError`` / ``RemoteProtocolError``)
        and retryable status codes (5xx, 408, 409, 429) are retried recursively
        with header-aware backoff; other outcomes are returned/raised as-is.
        """
        base_url = self.get_base_url(base_url)
        timeout = (
            request_options.get("timeout_in_seconds")
            if request_options is not None and request_options.get("timeout_in_seconds") is not None
            else self.base_timeout()
        )
        request_files: typing.Optional[RequestFiles] = (
            convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
            if (files is not None and files is not omit and isinstance(files, dict))
            else None
        )
        # Force multipart encoding even without files when requested.
        if (request_files is None or len(request_files) == 0) and force_multipart:
            request_files = FORCE_MULTIPART
        json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
        data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart)
        # Get headers (supports async token providers)
        _headers = await self._get_headers()
        # Compute encoded params separately to avoid passing empty list to httpx
        # (httpx strips existing query params from URL when params=[] is passed)
        _encoded_params = encode_query(
            jsonable_encoder(
                remove_none_from_dict(
                    remove_omit_from_dict(
                        {
                            **(params if params is not None else {}),
                            **(
                                request_options.get("additional_query_parameters", {}) or {}
                                if request_options is not None
                                else {}
                            ),
                        },
                        omit,
                    )
                )
            )
        )
        _request_url = _build_url(base_url, path)
        _request_headers = jsonable_encoder(
            remove_none_from_dict(
                {
                    **_headers,
                    **(headers if headers is not None else {}),
                    **(request_options.get("additional_headers", {}) or {} if request_options is not None else {}),
                }
            )
        )
        if self.logger.is_debug():
            self.logger.debug(
                "Making HTTP request",
                method=method,
                url=_request_url,
                headers=_redact_headers(_request_headers),
                has_body=json_body is not None or data_body is not None,
            )
        max_retries: int = (
            request_options.get("max_retries", self.base_max_retries)
            if request_options is not None
            else self.base_max_retries
        )
        try:
            response = await self.httpx_client.request(
                method=method,
                url=_request_url,
                headers=_request_headers,
                params=_encoded_params if _encoded_params else None,
                json=json_body,
                data=data_body,
                content=content,
                files=request_files,
                timeout=timeout,
            )
        except (httpx.ConnectError, httpx.RemoteProtocolError):
            # No response available: back off exponentially, then retry.
            if retries < max_retries:
                await asyncio.sleep(_retry_timeout_from_retries(retries=retries))
                return await self.request(
                    path=path,
                    method=method,
                    base_url=base_url,
                    params=params,
                    json=json,
                    data=data,
                    content=content,
                    files=files,
                    headers=headers,
                    request_options=request_options,
                    retries=retries + 1,
                    omit=omit,
                    force_multipart=force_multipart,
                )
            raise
        if _should_retry(response=response):
            if retries < max_retries:
                # Response-aware backoff (honors Retry-After / X-RateLimit-Reset).
                await asyncio.sleep(_retry_timeout(response=response, retries=retries))
                return await self.request(
                    path=path,
                    method=method,
                    base_url=base_url,
                    params=params,
                    json=json,
                    data=data,
                    content=content,
                    files=files,
                    headers=headers,
                    request_options=request_options,
                    retries=retries + 1,
                    omit=omit,
                    force_multipart=force_multipart,
                )
        if self.logger.is_debug():
            if 200 <= response.status_code < 400:
                self.logger.debug(
                    "HTTP request succeeded",
                    method=method,
                    url=_request_url,
                    status_code=response.status_code,
                )
        if self.logger.is_error():
            if response.status_code >= 400:
                self.logger.error(
                    "HTTP request failed with error status",
                    method=method,
                    url=_request_url,
                    status_code=response.status_code,
                )
        return response

    @asynccontextmanager
    async def stream(
        self,
        path: typing.Optional[str] = None,
        *,
        method: str,
        base_url: typing.Optional[str] = None,
        params: typing.Optional[typing.Dict[str, typing.Any]] = None,
        json: typing.Optional[typing.Any] = None,
        data: typing.Optional[typing.Any] = None,
        content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
        files: typing.Optional[
            typing.Union[
                typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
                typing.List[typing.Tuple[str, File]],
            ]
        ] = None,
        headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
        request_options: typing.Optional[RequestOptions] = None,
        retries: int = 0,
        omit: typing.Optional[typing.Any] = None,
        force_multipart: typing.Optional[bool] = None,
    ) -> typing.AsyncIterator[httpx.Response]:
        """Async context manager yielding a streaming response (no retries)."""
        base_url = self.get_base_url(base_url)
        timeout = (
            request_options.get("timeout_in_seconds")
            if request_options is not None and request_options.get("timeout_in_seconds") is not None
            else self.base_timeout()
        )
        request_files: typing.Optional[RequestFiles] = (
            convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
            if (files is not None and files is not omit and isinstance(files, dict))
            else None
        )
        if (request_files is None or len(request_files) == 0) and force_multipart:
            request_files = FORCE_MULTIPART
        json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
        data_body = _maybe_filter_none_from_multipart_data(data_body, request_files, force_multipart)
        # Get headers (supports async token providers)
        _headers = await self._get_headers()
        # Compute encoded params separately to avoid passing empty list to httpx
        # (httpx strips existing query params from URL when params=[] is passed)
        _encoded_params = encode_query(
            jsonable_encoder(
                remove_none_from_dict(
                    remove_omit_from_dict(
                        {
                            **(params if params is not None else {}),
                            **(
                                # BUG FIX: `or {}` guards against an explicit None
                                # value in request_options (mirrors request()).
                                request_options.get("additional_query_parameters", {}) or {}
                                if request_options is not None
                                else {}
                            ),
                        },
                        omit=omit,
                    )
                )
            )
        )
        _request_url = _build_url(base_url, path)
        _request_headers = jsonable_encoder(
            remove_none_from_dict(
                {
                    **_headers,
                    **(headers if headers is not None else {}),
                    # BUG FIX: `or {}` guards against an explicit None value
                    # (mirrors request()).
                    **(request_options.get("additional_headers", {}) or {} if request_options is not None else {}),
                }
            )
        )
        if self.logger.is_debug():
            self.logger.debug(
                "Making streaming HTTP request",
                method=method,
                url=_request_url,
                headers=_redact_headers(_request_headers),
            )
        async with self.httpx_client.stream(
            method=method,
            url=_request_url,
            headers=_request_headers,
            params=_encoded_params if _encoded_params else None,
            json=json_body,
            data=data_body,
            content=content,
            files=request_files,
            timeout=timeout,
        ) as stream:
            yield stream
================================================
FILE: src/cohere/core/http_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import Dict, Generic, TypeVar
import httpx
# Generic to represent the underlying type of the data wrapped by the HTTP response.
T = TypeVar("T")
class BaseHttpResponse:
    """Minimalist HTTP response wrapper that exposes response headers and status code."""

    # The wrapped httpx response.
    _response: httpx.Response

    def __init__(self, response: httpx.Response):
        self._response = response

    @property
    def headers(self) -> Dict[str, str]:
        """Response headers as a plain dict."""
        return dict(self._response.headers)

    @property
    def status_code(self) -> int:
        """HTTP status code of the response."""
        return self._response.status_code
class HttpResponse(Generic[T], BaseHttpResponse):
    """HTTP response wrapper that exposes response headers and data."""

    # The typed payload parsed from the response body.
    _data: T

    def __init__(self, response: httpx.Response, data: T):
        super().__init__(response)
        self._data = data

    @property
    def data(self) -> T:
        """The parsed response payload."""
        return self._data

    def close(self) -> None:
        """Close the underlying httpx response and release its resources."""
        self._response.close()
class AsyncHttpResponse(Generic[T], BaseHttpResponse):
    """HTTP response wrapper that exposes response headers and data."""

    # The typed payload parsed from the response body.
    _data: T

    def __init__(self, response: httpx.Response, data: T):
        super().__init__(response)
        self._data = data

    @property
    def data(self) -> T:
        """The parsed response payload."""
        return self._data

    async def close(self) -> None:
        """Asynchronously close the underlying httpx response."""
        await self._response.aclose()
================================================
FILE: src/cohere/core/http_sse/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from ._api import EventSource, aconnect_sse, connect_sse
from ._exceptions import SSEError
from ._models import ServerSentEvent
# Maps each public name to the relative submodule that defines it; consumed by
# the module-level __getattr__ for lazy (PEP 562) imports.
_dynamic_imports: typing.Dict[str, str] = {
    "EventSource": "._api",
    "SSEError": "._exceptions",
    "ServerSentEvent": "._models",
    "aconnect_sse": "._api",
    "connect_sse": "._api",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve public names on first attribute access (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # A ".name" mapping means the attribute is the submodule itself;
        # otherwise pull the attribute out of the imported submodule.
        resolved = module if module_name == f".{attr_name}" else getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
    return resolved
def __dir__():
    """Expose the lazily-importable names to dir() and autocompletion."""
    return sorted(_dynamic_imports)
# Public API of this package; the names resolve lazily via __getattr__.
__all__ = ["EventSource", "SSEError", "ServerSentEvent", "aconnect_sse", "connect_sse"]
================================================
FILE: src/cohere/core/http_sse/_api.py
================================================
# This file was auto-generated by Fern from our API Definition.
import re
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Iterator, cast
import httpx
from ._decoders import SSEDecoder
from ._exceptions import SSEError
from ._models import ServerSentEvent
class EventSource:
    """Decodes the body of an httpx response as a server-sent-events stream."""

    def __init__(self, response: httpx.Response) -> None:
        self._response = response

    def _check_content_type(self) -> None:
        """Raise SSEError unless the response advertises text/event-stream."""
        # Strip any ";charset=..." parameter before matching.
        content_type = self._response.headers.get("content-type", "").partition(";")[0]
        if "text/event-stream" not in content_type:
            raise SSEError(
                f"Expected response header Content-Type to contain 'text/event-stream', got {content_type!r}"
            )

    def _get_charset(self) -> str:
        """Extract charset from Content-Type header, fallback to UTF-8."""
        content_type = self._response.headers.get("content-type", "")
        # Parse charset parameter using regex
        charset_match = re.search(r"charset=([^;\s]+)", content_type, re.IGNORECASE)
        if charset_match:
            charset = charset_match.group(1).strip("\"'")
            # Validate that it's a known encoding
            try:
                # Test if the charset is valid by trying to encode/decode
                "test".encode(charset).decode(charset)
                return charset
            except (LookupError, UnicodeError):
                # If charset is invalid, fall back to UTF-8
                pass
        # Default to UTF-8 if no charset specified or invalid charset
        return "utf-8"

    @property
    def response(self) -> httpx.Response:
        """The underlying httpx response."""
        return self._response

    def iter_sse(self) -> Iterator[ServerSentEvent]:
        """Synchronously yield decoded events from the response body.

        Bytes are decoded with the declared charset and split into lines
        manually, so lines spanning chunk boundaries are handled correctly.
        """
        self._check_content_type()
        decoder = SSEDecoder()
        charset = self._get_charset()
        buffer = ""
        for chunk in self._response.iter_bytes():
            # Decode chunk using detected charset
            text_chunk = chunk.decode(charset, errors="replace")
            buffer += text_chunk
            # Process complete lines
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                line = line.rstrip("\r")
                sse = decoder.decode(line)
                # when we reach a "\n\n" => line = ''
                # => decoder will attempt to return an SSE Event
                if sse is not None:
                    yield sse
        # Process any remaining data in buffer
        if buffer.strip():
            line = buffer.rstrip("\r")
            sse = decoder.decode(line)
            if sse is not None:
                yield sse

    async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]:
        """Asynchronously yield decoded events from the response body."""
        self._check_content_type()
        decoder = SSEDecoder()
        lines = cast(AsyncGenerator[str, None], self._response.aiter_lines())
        try:
            async for line in lines:
                line = line.rstrip("\n")
                sse = decoder.decode(line)
                if sse is not None:
                    yield sse
        finally:
            # Close the underlying line generator even if the consumer
            # abandons this generator early.
            await lines.aclose()
@contextmanager
def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any) -> Iterator[EventSource]:
    """Open a streaming request and yield it wrapped as an EventSource.

    The ``Accept`` and ``Cache-Control`` headers required for SSE are set on a
    merged copy of any caller-supplied headers.
    """
    # BUG FIX: merge into a new dict instead of mutating the caller's headers
    # dict in place (the previous code wrote the two SSE headers into it).
    headers = {
        **kwargs.pop("headers", {}),
        "Accept": "text/event-stream",
        "Cache-Control": "no-store",
    }
    with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)
@asynccontextmanager
async def aconnect_sse(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    **kwargs: Any,
) -> AsyncIterator[EventSource]:
    """Async variant of ``connect_sse``: open a streaming request as an EventSource.

    The ``Accept`` and ``Cache-Control`` headers required for SSE are set on a
    merged copy of any caller-supplied headers.
    """
    # BUG FIX: merge into a new dict instead of mutating the caller's headers
    # dict in place (the previous code wrote the two SSE headers into it).
    headers = {
        **kwargs.pop("headers", {}),
        "Accept": "text/event-stream",
        "Cache-Control": "no-store",
    }
    async with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)
================================================
FILE: src/cohere/core/http_sse/_decoders.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import List, Optional
from ._models import ServerSentEvent
class SSEDecoder:
    """Incremental decoder for the server-sent-events wire format.

    Feed lines one at a time to :meth:`decode`; a completed event is returned
    when the blank separator line is seen.
    See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501
    """

    def __init__(self) -> None:
        self._event = ""
        self._data: List[str] = []
        self._last_event_id = ""
        self._retry: Optional[int] = None

    def decode(self, line: str) -> Optional[ServerSentEvent]:
        """Consume one line; return a ServerSentEvent when one completes."""
        if not line:
            # Blank line terminates the current event, if any state accumulated.
            return self._flush()
        if line.startswith(":"):
            # Comment line — ignored per spec.
            return None
        fieldname, _, value = line.partition(":")
        # A single leading space in the value is stripped per spec.
        if value.startswith(" "):
            value = value[1:]
        self._apply_field(fieldname, value)
        return None

    def _flush(self) -> Optional[ServerSentEvent]:
        """Emit the buffered event and reset per-event state."""
        if not (self._event or self._data or self._last_event_id or self._retry is not None):
            return None
        completed = ServerSentEvent(
            event=self._event,
            data="\n".join(self._data),
            id=self._last_event_id,
            retry=self._retry,
        )
        # NOTE: as per the SSE spec, last_event_id is NOT reset between events.
        self._event = ""
        self._data = []
        self._retry = None
        return completed

    def _apply_field(self, fieldname: str, value: str) -> None:
        """Apply a single ``field: value`` line to the decoder state."""
        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # Ids containing NUL are ignored per spec.
            if "\0" not in value:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        # Unknown field names are ignored per spec.
================================================
FILE: src/cohere/core/http_sse/_exceptions.py
================================================
# This file was auto-generated by Fern from our API Definition.
import httpx
class SSEError(httpx.TransportError):
    """Raised when a response expected to be a server-sent-events stream is not one."""

    pass
================================================
FILE: src/cohere/core/http_sse/_models.py
================================================
# This file was auto-generated by Fern from our API Definition.
import json
from dataclasses import dataclass
from typing import Any, Optional
@dataclass(frozen=True)
class ServerSentEvent:
    """A single event decoded from a server-sent-events stream.

    Field defaults follow the SSE specification.
    """

    # Event type; the spec default is "message".
    event: str = "message"
    # Payload text (multi-line data fields joined with newlines).
    data: str = ""
    # Event id, used by clients for stream resumption.
    id: str = ""
    # Server-requested reconnection delay in milliseconds, when sent.
    retry: Optional[int] = None

    def json(self) -> Any:
        """Parse the data field as JSON."""
        return json.loads(self.data)
================================================
FILE: src/cohere/core/jsonable_encoder.py
================================================
# This file was auto-generated by Fern from our API Definition.
"""
jsonable_encoder converts a Python object to a JSON-friendly dict
(e.g. datetimes to strings, Pydantic models to dicts).
Taken from FastAPI, and made a bit simpler
https://github.com/tiangolo/fastapi/blob/master/fastapi/encoders.py
"""
import base64
import dataclasses
import datetime as dt
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Union
import pydantic
from .datetime_utils import serialize_datetime
from .pydantic_utilities import (
IS_PYDANTIC_V2,
encode_by_type,
to_jsonable_with_fallback,
)
# Type aliases retained from the FastAPI implementation this module was adapted from.
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None) -> Any:
    """Recursively convert *obj* into JSON-friendly Python primitives.

    Handles pydantic models (v1 and v2), dataclasses, enums, paths, datetimes,
    bytes (base64), and nested containers. ``custom_encoder`` maps types to
    callables that take precedence over the built-in handling.
    """
    custom_encoder = custom_encoder or {}
    # Generated SDKs use Ellipsis (`...`) as the sentinel value for "OMIT".
    # OMIT values should be excluded from serialized payloads.
    if obj is Ellipsis:
        return None
    if custom_encoder:
        if type(obj) in custom_encoder:
            return custom_encoder[type(obj)](obj)
        else:
            for encoder_type, encoder_instance in custom_encoder.items():
                if isinstance(obj, encoder_type):
                    return encoder_instance(obj)
    if isinstance(obj, pydantic.BaseModel):
        if IS_PYDANTIC_V2:
            encoder = getattr(obj.model_config, "json_encoders", {})  # type: ignore # Pydantic v2
        else:
            encoder = getattr(obj.__config__, "json_encoders", {})  # type: ignore # Pydantic v1
        if custom_encoder:
            # BUG FIX: build a merged copy instead of update()-ing the model's
            # shared json_encoders config dict in place, which leaked custom
            # encoders into every later call for the same model class.
            encoder = {**encoder, **custom_encoder}
        obj_dict = obj.dict(by_alias=True)
        # Unwrap pydantic root models ("__root__" in v1, "root" in v2).
        # BUG FIX: elif — after unwrapping "__root__" to a scalar/list value,
        # a second `"root" in obj_dict` test could raise TypeError (e.g. for
        # an int root) or falsely match a substring of a str root.
        if "__root__" in obj_dict:
            obj_dict = obj_dict["__root__"]
        elif "root" in obj_dict:
            obj_dict = obj_dict["root"]
        return jsonable_encoder(obj_dict, custom_encoder=encoder)
    if dataclasses.is_dataclass(obj):
        obj_dict = dataclasses.asdict(obj)  # type: ignore
        return jsonable_encoder(obj_dict, custom_encoder=custom_encoder)
    if isinstance(obj, bytes):
        return base64.b64encode(obj).decode("utf-8")
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, PurePath):
        return str(obj)
    if isinstance(obj, (str, int, float, type(None))):
        return obj
    if isinstance(obj, dt.datetime):
        return serialize_datetime(obj)
    if isinstance(obj, dt.date):
        return str(obj)
    if isinstance(obj, dict):
        # Note: the previous `allowed_keys` filter was dead code (it contained
        # every key of the dict); removed.
        encoded_dict = {}
        for key, value in obj.items():
            if value is Ellipsis:
                continue
            encoded_key = jsonable_encoder(key, custom_encoder=custom_encoder)
            encoded_value = jsonable_encoder(value, custom_encoder=custom_encoder)
            encoded_dict[encoded_key] = encoded_value
        return encoded_dict
    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
        encoded_list = []
        for item in obj:
            if item is Ellipsis:
                continue
            encoded_list.append(jsonable_encoder(item, custom_encoder=custom_encoder))
        return encoded_list

    def fallback_serializer(o: Any) -> Any:
        # Last resort: try registered type encoders, then dict()/vars().
        attempt_encode = encode_by_type(o)
        if attempt_encode is not None:
            return attempt_encode
        try:
            data = dict(o)
        except Exception as e:
            errors: List[Exception] = []
            errors.append(e)
            try:
                data = vars(o)
            except Exception as e:
                errors.append(e)
                raise ValueError(errors) from e
        return jsonable_encoder(data, custom_encoder=custom_encoder)

    return to_jsonable_with_fallback(obj, fallback_serializer)
================================================
FILE: src/cohere/core/logging.py
================================================
# This file was auto-generated by Fern from our API Definition.
import logging
import typing
# The log levels accepted by the SDK's logging configuration.
LogLevel = typing.Literal["debug", "info", "warn", "error"]

# Numeric severity ranking used by Logger._should_log to compare a message's
# level against the configured threshold (larger number = more severe).
_LOG_LEVEL_MAP: typing.Dict[LogLevel, int] = {
    "debug": 1,
    "info": 2,
    "warn": 3,
    "error": 4,
}
class ILogger(typing.Protocol):
    """Structural (duck-typed) interface for logger objects used by the SDK.

    Any object exposing these four methods can be supplied via ``LogConfig``.
    Keyword arguments carry structured context alongside the message.
    """

    def debug(self, message: str, **kwargs: typing.Any) -> None: ...

    def info(self, message: str, **kwargs: typing.Any) -> None: ...

    def warn(self, message: str, **kwargs: typing.Any) -> None: ...

    def error(self, message: str, **kwargs: typing.Any) -> None: ...
class ConsoleLogger:
    """ILogger implementation backed by the standard :mod:`logging` module.

    Every instance shares the single ``"fern"`` logger. A stream handler
    (with a ``LEVEL - message`` format) is attached only once, the first time
    an instance finds the logger without handlers; at that point the logger's
    level is set to DEBUG so the Logger facade controls filtering.
    """

    _logger: logging.Logger

    def __init__(self) -> None:
        shared = logging.getLogger("fern")
        if not shared.handlers:
            stream_handler = logging.StreamHandler()
            stream_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))
            shared.addHandler(stream_handler)
            shared.setLevel(logging.DEBUG)
        self._logger = shared

    def debug(self, message: str, **kwargs: typing.Any) -> None:
        """Emit *message* at DEBUG level; kwargs become ``extra`` record fields."""
        self._logger.debug(message, extra=kwargs)

    def info(self, message: str, **kwargs: typing.Any) -> None:
        """Emit *message* at INFO level; kwargs become ``extra`` record fields."""
        self._logger.info(message, extra=kwargs)

    def warn(self, message: str, **kwargs: typing.Any) -> None:
        """Emit *message* at WARNING level; kwargs become ``extra`` record fields."""
        self._logger.warning(message, extra=kwargs)

    def error(self, message: str, **kwargs: typing.Any) -> None:
        """Emit *message* at ERROR level; kwargs become ``extra`` record fields."""
        self._logger.error(message, extra=kwargs)
class LogConfig(typing.TypedDict, total=False):
    """Optional logging configuration accepted by ``create_logger``."""

    # Minimum level to emit; create_logger defaults this to "info" when omitted.
    level: LogLevel
    # Destination logger; create_logger defaults to a ConsoleLogger when omitted.
    logger: ILogger
    # When True (create_logger's default), suppresses all output.
    silent: bool
class Logger:
    """Level-filtered logging facade used throughout the SDK.

    Wraps an :class:`ILogger` and drops every message whose severity is below
    the configured threshold, or all messages when ``silent`` is set.
    """

    _level: int
    _logger: ILogger
    _silent: bool

    def __init__(self, *, level: LogLevel, logger: ILogger, silent: bool) -> None:
        self._level = _LOG_LEVEL_MAP[level]
        self._logger = logger
        self._silent = silent

    def _should_log(self, level: LogLevel) -> bool:
        # Silence wins unconditionally; otherwise compare numeric severities.
        if self._silent:
            return False
        return _LOG_LEVEL_MAP[level] >= self._level

    def is_debug(self) -> bool:
        """True when debug-level messages would be emitted."""
        return self._should_log("debug")

    def is_info(self) -> bool:
        """True when info-level messages would be emitted."""
        return self._should_log("info")

    def is_warn(self) -> bool:
        """True when warn-level messages would be emitted."""
        return self._should_log("warn")

    def is_error(self) -> bool:
        """True when error-level messages would be emitted."""
        return self._should_log("error")

    def debug(self, message: str, **kwargs: typing.Any) -> None:
        if self._should_log("debug"):
            self._logger.debug(message, **kwargs)

    def info(self, message: str, **kwargs: typing.Any) -> None:
        if self._should_log("info"):
            self._logger.info(message, **kwargs)

    def warn(self, message: str, **kwargs: typing.Any) -> None:
        if self._should_log("warn"):
            self._logger.warn(message, **kwargs)

    def error(self, message: str, **kwargs: typing.Any) -> None:
        if self._should_log("error"):
            self._logger.error(message, **kwargs)
# Shared fallback: an info-threshold console logger that is silenced, so the
# SDK emits nothing unless the caller opts in.
_default_logger: Logger = Logger(level="info", logger=ConsoleLogger(), silent=True)


def create_logger(config: typing.Optional[typing.Union[LogConfig, Logger]] = None) -> Logger:
    """Resolve a Logger from user configuration.

    An existing ``Logger`` passes through unchanged; ``None`` yields the
    silent module default; a ``LogConfig`` dict is expanded with defaults
    (level "info", console output, silent=True).
    """
    if isinstance(config, Logger):
        return config
    if config is None:
        return _default_logger
    return Logger(
        level=config.get("level", "info"),
        logger=config.get("logger", ConsoleLogger()),
        silent=config.get("silent", True),
    )
================================================
FILE: src/cohere/core/parse_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import Any, Dict, Optional
class ParsingError(Exception):
    """
    Raised when the SDK fails to parse/validate a response from the server.

    This typically indicates that the server returned a response whose shape
    does not match the expected schema.
    """

    headers: Optional[Dict[str, str]]
    status_code: Optional[int]
    body: Any
    cause: Optional[Exception]

    def __init__(
        self,
        *,
        headers: Optional[Dict[str, str]] = None,
        status_code: Optional[int] = None,
        body: Any = None,
        cause: Optional[Exception] = None,
    ) -> None:
        super().__init__()
        self.headers = headers
        self.status_code = status_code
        self.body = body
        self.cause = cause
        # Wire the originating exception into standard chaining so tracebacks
        # show "The above exception was the direct cause of ...".
        if cause is not None:
            self.__cause__ = cause

    def __str__(self) -> str:
        fragments = [
            f"headers: {self.headers}",
            f"status_code: {self.status_code}",
            f"body: {self.body}",
        ]
        if self.cause is not None:
            fragments.append(f"cause: {self.cause}")
        return ", ".join(fragments)
================================================
FILE: src/cohere/core/pydantic_utilities.py
================================================
# This file was auto-generated by Fern from our API Definition.
# nopycln: file
import datetime as dt
import inspect
import json
import logging
from collections import defaultdict
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import pydantic
import typing_extensions
from pydantic.fields import FieldInfo as _FieldInfo
_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from .http_sse._models import ServerSentEvent
# True when the installed Pydantic is a 2.x release; gates every compat branch below.
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

if IS_PYDANTIC_V2:
    # Reusable validators for coercing wire values into datetimes/dates.
    _datetime_adapter = pydantic.TypeAdapter(dt.datetime)  # type: ignore[attr-defined]
    _date_adapter = pydantic.TypeAdapter(dt.date)  # type: ignore[attr-defined]

    def parse_datetime(value: Any) -> dt.datetime:  # type: ignore[misc]
        """Coerce *value* to a datetime; real datetimes pass through untouched."""
        if isinstance(value, dt.datetime):
            return value
        return _datetime_adapter.validate_python(value)

    def parse_date(value: Any) -> dt.date:  # type: ignore[misc]
        """Coerce *value* to a date; datetimes are truncated to their date part."""
        if isinstance(value, dt.datetime):
            return value.date()
        if isinstance(value, dt.date):
            return value
        return _date_adapter.validate_python(value)

    # Avoid importing from pydantic.v1 to maintain Python 3.14 compatibility.
    from typing import get_args as get_args  # type: ignore[assignment]
    from typing import get_origin as get_origin  # type: ignore[assignment]

    def is_literal_type(tp: Optional[Type[Any]]) -> bool:  # type: ignore[misc]
        """True when *tp* is a ``Literal[...]`` form."""
        return typing_extensions.get_origin(tp) is typing_extensions.Literal

    def is_union(tp: Optional[Type[Any]]) -> bool:  # type: ignore[misc]
        """True when *tp* is ``Union`` itself or a parameterized ``Union[...]``."""
        return tp is Union or typing_extensions.get_origin(tp) is Union  # type: ignore[comparison-overlap]

    # Inline encoders_by_type to avoid importing from pydantic.v1.json
    import re as _re
    from collections import deque as _deque
    from decimal import Decimal as _Decimal
    from enum import Enum as _Enum
    from ipaddress import (
        IPv4Address as _IPv4Address,
    )
    from ipaddress import (
        IPv4Interface as _IPv4Interface,
    )
    from ipaddress import (
        IPv4Network as _IPv4Network,
    )
    from ipaddress import (
        IPv6Address as _IPv6Address,
    )
    from ipaddress import (
        IPv6Interface as _IPv6Interface,
    )
    from ipaddress import (
        IPv6Network as _IPv6Network,
    )
    from pathlib import Path as _Path
    from types import GeneratorType as _GeneratorType
    from uuid import UUID as _UUID

    from pydantic.fields import FieldInfo as ModelField  # type: ignore[no-redef, assignment]

    def _decimal_encoder(dec_value: Any) -> Any:
        """Encode a Decimal as int when it has no fractional digits, else as float."""
        if dec_value.as_tuple().exponent >= 0:
            return int(dec_value)
        return float(dec_value)

    # Default JSON encoders keyed by type, mirroring pydantic v1's ENCODERS_BY_TYPE.
    encoders_by_type: Dict[Type[Any], Callable[[Any], Any]] = {  # type: ignore[no-redef]
        bytes: lambda o: o.decode(),
        dt.date: lambda o: o.isoformat(),
        dt.datetime: lambda o: o.isoformat(),
        dt.time: lambda o: o.isoformat(),
        dt.timedelta: lambda td: td.total_seconds(),
        _Decimal: _decimal_encoder,
        _Enum: lambda o: o.value,
        frozenset: list,
        _deque: list,
        _GeneratorType: list,
        _IPv4Address: str,
        _IPv4Interface: str,
        _IPv4Network: str,
        _IPv6Address: str,
        _IPv6Interface: str,
        _IPv6Network: str,
        _Path: str,
        _re.Pattern: lambda o: o.pattern,
        set: list,
        _UUID: str,
    }
else:
    # Pydantic v1: reuse its own helpers directly under the same names.
    from pydantic.datetime_parse import parse_date as parse_date  # type: ignore[no-redef]
    from pydantic.datetime_parse import parse_datetime as parse_datetime  # type: ignore[no-redef]
    from pydantic.fields import ModelField as ModelField  # type: ignore[attr-defined, no-redef, assignment]
    from pydantic.json import ENCODERS_BY_TYPE as encoders_by_type  # type: ignore[no-redef]
    from pydantic.typing import get_args as get_args  # type: ignore[no-redef]
    from pydantic.typing import get_origin as get_origin  # type: ignore[no-redef]
    from pydantic.typing import is_literal_type as is_literal_type  # type: ignore[no-redef, assignment]
    from pydantic.typing import is_union as is_union  # type: ignore[no-redef]
from .datetime_utils import serialize_datetime
from .serialization import convert_and_respect_annotation_metadata
from typing_extensions import TypeAlias
# Generic type variables shared by the parsing helpers below.
T = TypeVar("T")
# Bound to pydantic models so helpers can return the concrete model subclass.
Model = TypeVar("Model", bound=pydantic.BaseModel)
def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]:
    """
    Pull ``(discriminator_name, union_variants)`` out of a discriminated union.

    Supports the ``Annotated[Union[...], Field(discriminator=...)]`` pattern;
    anything else yields ``(None, None)``.
    """
    if typing_extensions.get_origin(type_) is not typing_extensions.Annotated:
        return None, None
    annotated_args = typing_extensions.get_args(type_)
    if len(annotated_args) < 2:
        return None, None
    wrapped, *metadata = annotated_args
    # First metadata object carrying a `discriminator` attribute wins.
    discriminator = next(
        (getattr(meta, "discriminator", None) for meta in metadata if hasattr(meta, "discriminator")),
        None,
    )
    if not discriminator:
        return None, None
    if typing_extensions.get_origin(wrapped) is not Union:
        return None, None
    return discriminator, list(typing_extensions.get_args(wrapped))
def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]:
    """Return the declared type of ``field_name`` on *model*, or None if absent."""
    if IS_PYDANTIC_V2:
        # Pydantic v2 exposes FieldInfo objects through `model_fields`.
        info = getattr(model, "model_fields", {}).get(field_name)
        if info:
            return cast(Optional[Type[Any]], info.annotation)
        return None
    # Pydantic v1 keeps ModelField objects on `__fields__`.
    info = getattr(model, "__fields__", {}).get(field_name)
    if info:
        return cast(Optional[Type[Any]], info.outer_type_)
    return None
def _find_variant_by_discriminator(
    variants: List[Type[Any]],
    discriminator: str,
    discriminator_value: Any,
) -> Optional[Type[Any]]:
    """Return the first BaseModel variant whose literal discriminator field matches the value, else None."""
    for candidate in variants:
        # Only concrete pydantic model classes can carry a discriminator field.
        if not (inspect.isclass(candidate) and issubclass(candidate, pydantic.BaseModel)):
            continue
        annotation = _get_field_annotation(candidate, discriminator)
        if not annotation or not is_literal_type(annotation):
            continue
        literal_values = get_args(annotation)
        if literal_values and literal_values[0] == discriminator_value:
            return candidate
    return None
def _is_string_type(type_: Type[Any]) -> bool:
    """True for ``str`` itself and for ``Optional[str]`` (i.e. ``Union[str, None]``)."""
    if type_ is str:
        return True
    if typing_extensions.get_origin(type_) is not Union:
        return False
    # Strip NoneType; a bare Optional[str] leaves exactly one member: str.
    members = [arg for arg in typing_extensions.get_args(type_) if arg is not type(None)]
    return len(members) == 1 and members[0] is str
def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
    """
    Parse a ServerSentEvent into the appropriate type.

    Handles two scenarios based on where the discriminator field is located:

    1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload.
       The union describes the data content, not the SSE envelope.
       -> Returns: json.loads(data) parsed into the type

       Example: ChatStreamResponse with discriminator='type'
         Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="")
         Output: ContentDeltaEvent (parsed from data, SSE envelope stripped)

    2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level.
       The union describes the full SSE event structure.
       -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string

       Example: JobStreamResponse with discriminator='event'
         Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123")
         Output: JobStreamResponse_Error with data as ErrorData object

       But for variants where data is str (like STATUS_UPDATE):
         Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1")
         Output: JobStreamResponse_StatusUpdate with data as string (not parsed)

    Args:
        sse: The ServerSentEvent object to parse
        type_: The target discriminated union type

    Returns:
        The parsed object of type T

    Note:
        This function is only available in SDK contexts where http_sse module exists.
    """
    # Work on a plain-dict view of the SSE dataclass.
    sse_event = asdict(sse)
    discriminator, variants = _get_discriminator_and_variants(type_)
    if discriminator is None or variants is None:
        # Not a discriminated union - parse the data field as JSON
        data_value = sse_event.get("data")
        if isinstance(data_value, str) and data_value:
            try:
                parsed_data = json.loads(data_value)
                return parse_obj_as(type_, parsed_data)
            except json.JSONDecodeError as e:
                # Fall through and try to parse the whole envelope instead;
                # data is truncated to 100 chars to keep the log line bounded.
                _logger.warning(
                    "Failed to parse SSE data field as JSON: %s, data: %s",
                    e,
                    data_value[:100] if len(data_value) > 100 else data_value,
                )
        return parse_obj_as(type_, sse_event)
    data_value = sse_event.get("data")
    # Check if discriminator is at the top level (event-level discrimination)
    if discriminator in sse_event:
        # Case 2: Event-level discrimination
        # Find the matching variant to check if 'data' field needs JSON parsing
        disc_value = sse_event.get(discriminator)
        matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value)
        if matching_variant is not None:
            # Check what type the variant expects for 'data'
            data_type = _get_field_annotation(matching_variant, "data")
            if data_type is not None and not _is_string_type(data_type):
                # Variant expects non-string data - parse JSON
                if isinstance(data_value, str) and data_value:
                    try:
                        parsed_data = json.loads(data_value)
                        # Copy the envelope so the caller's dict is not mutated.
                        new_object = dict(sse_event)
                        new_object["data"] = parsed_data
                        return parse_obj_as(type_, new_object)
                    except json.JSONDecodeError as e:
                        _logger.warning(
                            "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s",
                            e,
                            data_value[:100] if len(data_value) > 100 else data_value,
                        )
        # Either no matching variant, data is string type, or JSON parse failed
        return parse_obj_as(type_, sse_event)
    else:
        # Case 1: Data-level discrimination
        # The discriminator is inside the data payload - extract and parse data only
        if isinstance(data_value, str) and data_value:
            try:
                parsed_data = json.loads(data_value)
                return parse_obj_as(type_, parsed_data)
            except json.JSONDecodeError as e:
                _logger.warning(
                    "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s",
                    e,
                    data_value[:100] if len(data_value) > 100 else data_value,
                )
        return parse_obj_as(type_, sse_event)
def parse_obj_as(type_: Type[T], object_: Any) -> T:
    """Validate/parse *object_* into ``type_`` under either Pydantic major version.

    Non-model annotations (TypedDicts, containers) are pre-dealiased via
    FieldMetadata before validation; models are pre-dealiased only when they
    carry no real Pydantic aliases (see the rationale below).
    """
    # convert_and_respect_annotation_metadata is required for TypedDict aliasing.
    #
    # For Pydantic models, whether we should pre-dealias depends on how the model encodes aliasing:
    # - If the model uses real Pydantic aliases (pydantic.Field(alias=...)), then we must pass wire keys through
    #   unchanged so Pydantic can validate them.
    # - If the model encodes aliasing only via FieldMetadata annotations, then we MUST pre-dealias because Pydantic
    #   will not recognize those aliases during validation.
    if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
        has_pydantic_aliases = False
        if IS_PYDANTIC_V2:
            for field_name, field_info in getattr(type_, "model_fields", {}).items():  # type: ignore[attr-defined]
                alias = getattr(field_info, "alias", None)
                if alias is not None and alias != field_name:
                    has_pydantic_aliases = True
                    break
        else:
            for field in getattr(type_, "__fields__", {}).values():
                alias = getattr(field, "alias", None)
                name = getattr(field, "name", None)
                if alias is not None and name is not None and alias != name:
                    has_pydantic_aliases = True
                    break
        dealiased_object = (
            object_
            if has_pydantic_aliases
            else convert_and_respect_annotation_metadata(object_=object_, annotation=type_, direction="read")
        )
    else:
        dealiased_object = convert_and_respect_annotation_metadata(object_=object_, annotation=type_, direction="read")
    if IS_PYDANTIC_V2:
        adapter = pydantic.TypeAdapter(type_)  # type: ignore[attr-defined]
        return adapter.validate_python(dealiased_object)
    return pydantic.parse_obj_as(type_, dealiased_object)
def to_jsonable_with_fallback(obj: Any, fallback_serializer: Callable[[Any], Any]) -> Any:
    """Serialize via pydantic-core on Pydantic v2; otherwise defer entirely to *fallback_serializer*."""
    if not IS_PYDANTIC_V2:
        return fallback_serializer(obj)
    # Local import: pydantic_core only exists alongside Pydantic v2.
    from pydantic_core import to_jsonable_python

    return to_jsonable_python(obj, fallback=fallback_serializer)
class UniversalBaseModel(pydantic.BaseModel):
    """Base model for generated SDK types, smoothing over Pydantic v1/v2 differences.

    Provides: acceptance of either Python field names or wire aliases on input,
    alias-based serialization with SDK datetime formatting, and ``dict``/``json``/
    ``construct`` overrides that behave consistently across both Pydantic majors.
    """

    if IS_PYDANTIC_V2:
        model_config: ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(  # type: ignore[typeddict-unknown-key]
            # Allow fields beginning with `model_` to be used in the model
            protected_namespaces=(),
        )

        @pydantic.model_validator(mode="before")  # type: ignore[attr-defined]
        @classmethod
        def _coerce_field_names_to_aliases(cls, data: Any) -> Any:
            """
            Accept Python field names in input by rewriting them to their Pydantic aliases,
            while avoiding silent collisions when a key could refer to multiple fields.
            """
            if not isinstance(data, Mapping):
                return data
            fields = getattr(cls, "model_fields", {})  # type: ignore[attr-defined]
            name_to_alias: Dict[str, str] = {}
            alias_to_name: Dict[str, str] = {}
            for name, field_info in fields.items():
                alias = getattr(field_info, "alias", None) or name
                name_to_alias[name] = alias
                if alias != name:
                    alias_to_name[alias] = name
            # Detect ambiguous keys: a key that is an alias for one field and a name for another.
            ambiguous_keys = set(alias_to_name.keys()).intersection(set(name_to_alias.keys()))
            for key in ambiguous_keys:
                if key in data and name_to_alias[key] not in data:
                    raise ValueError(
                        f"Ambiguous input key '{key}': it is both a field name and an alias. "
                        "Provide the explicit alias key to disambiguate."
                    )
            original_keys = set(data.keys())
            rewritten: Dict[str, Any] = dict(data)
            # Rewrite name -> alias, but never clobber an alias key the caller supplied.
            for name, alias in name_to_alias.items():
                if alias != name and name in original_keys and alias not in rewritten:
                    rewritten[alias] = rewritten.pop(name)
            return rewritten

        @pydantic.model_serializer(mode="plain", when_used="json")  # type: ignore[attr-defined]
        def serialize_model(self) -> Any:  # type: ignore[name-defined]
            """JSON serialization via the patched dict(); only top-level datetime
            values are rewritten into the SDK wire format here."""
            serialized = self.dict()  # type: ignore[attr-defined]
            data = {k: serialize_datetime(v) if isinstance(v, dt.datetime) else v for k, v in serialized.items()}
            return data

    else:

        class Config:
            # Try every Union member before coercing (Pydantic v1 behavior knob).
            smart_union = True
            json_encoders = {dt.datetime: serialize_datetime}

        @pydantic.root_validator(pre=True)
        def _coerce_field_names_to_aliases(cls, values: Any) -> Any:
            """
            Pydantic v1 equivalent of _coerce_field_names_to_aliases.
            """
            if not isinstance(values, Mapping):
                return values
            fields = getattr(cls, "__fields__", {})
            name_to_alias: Dict[str, str] = {}
            alias_to_name: Dict[str, str] = {}
            for name, field in fields.items():
                alias = getattr(field, "alias", None) or name
                name_to_alias[name] = alias
                if alias != name:
                    alias_to_name[alias] = name
            ambiguous_keys = set(alias_to_name.keys()).intersection(set(name_to_alias.keys()))
            for key in ambiguous_keys:
                if key in values and name_to_alias[key] not in values:
                    raise ValueError(
                        f"Ambiguous input key '{key}': it is both a field name and an alias. "
                        "Provide the explicit alias key to disambiguate."
                    )
            original_keys = set(values.keys())
            rewritten: Dict[str, Any] = dict(values)
            for name, alias in name_to_alias.items():
                if alias != name and name in original_keys and alias not in rewritten:
                    rewritten[alias] = rewritten.pop(name)
            return rewritten

    @classmethod
    def model_construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
        """Construct without validation, first mapping FieldMetadata aliases to field names."""
        dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
        return cls.construct(_fields_set, **dealiased_object)

    @classmethod
    def construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
        """Validation-free constructor that works under either Pydantic major version."""
        dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
        if IS_PYDANTIC_V2:
            return super().model_construct(_fields_set, **dealiased_object)  # type: ignore[misc]
        return super().construct(_fields_set, **dealiased_object)

    def json(self, **kwargs: Any) -> str:
        """Serialize to JSON, defaulting to by-alias keys and omitting unset fields."""
        kwargs_with_defaults = {
            "by_alias": True,
            "exclude_unset": True,
            **kwargs,
        }
        if IS_PYDANTIC_V2:
            return super().model_dump_json(**kwargs_with_defaults)  # type: ignore[misc]
        return super().json(**kwargs_with_defaults)

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Override the default dict method to `exclude_unset` by default. This function patches
        `exclude_unset` to work include fields within non-None default values.
        """
        # Note: the logic here is multiplexed given the levers exposed in Pydantic V1 vs V2
        # Pydantic V1's .dict can be extremely slow, so we do not want to call it twice.
        #
        # We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
        # that we have less control over, and this is less intrusive than custom serializers for now.
        if IS_PYDANTIC_V2:
            kwargs_with_defaults_exclude_unset = {
                **kwargs,
                "by_alias": True,
                "exclude_unset": True,
                "exclude_none": False,
            }
            kwargs_with_defaults_exclude_none = {
                **kwargs,
                "by_alias": True,
                "exclude_none": True,
                "exclude_unset": False,
            }
            # Union of two dumps: unset-excluded (caller intent) overlaid onto
            # none-excluded (defaults), so non-None defaults survive.
            dict_dump = deep_union_pydantic_dicts(
                super().model_dump(**kwargs_with_defaults_exclude_unset),  # type: ignore[misc]
                super().model_dump(**kwargs_with_defaults_exclude_none),  # type: ignore[misc]
            )
        else:
            _fields_set = self.__fields_set__.copy()
            fields = _get_model_fields(self.__class__)
            for name, field in fields.items():
                if name not in _fields_set:
                    default = _get_field_default(field)
                    # If the default values are non-null act like they've been set
                    # This effectively allows exclude_unset to work like exclude_none where
                    # the latter passes through intentionally set none values.
                    if default is not None or ("exclude_unset" in kwargs and not kwargs["exclude_unset"]):
                        _fields_set.add(name)
                        if default is not None:
                            self.__fields_set__.add(name)
            kwargs_with_defaults_exclude_unset_include_fields = {
                "by_alias": True,
                "exclude_unset": True,
                "include": _fields_set,
                **kwargs,
            }
            dict_dump = super().dict(**kwargs_with_defaults_exclude_unset_include_fields)
        return cast(
            Dict[str, Any],
            convert_and_respect_annotation_metadata(object_=dict_dump, annotation=self.__class__, direction="write"),
        )
def _union_list_of_pydantic_dicts(source: List[Any], destination: List[Any]) -> List[Any]:
converted_list: List[Any] = []
for i, item in enumerate(source):
destination_value = destination[i]
if isinstance(item, dict):
converted_list.append(deep_union_pydantic_dicts(item, destination_value))
elif isinstance(item, list):
converted_list.append(_union_list_of_pydantic_dicts(item, destination_value))
else:
converted_list.append(item)
return converted_list
def deep_union_pydantic_dicts(source: Dict[str, Any], destination: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively fold *source* into *destination* (mutated in place) and return it."""
    for key, incoming in source.items():
        # setdefault seeds an empty dict so nested merges have a target node.
        existing = destination.setdefault(key, {})
        if isinstance(incoming, dict):
            deep_union_pydantic_dicts(incoming, existing)
        elif isinstance(incoming, list):
            # Note: we do not do this same processing for sets given we do not have sets of models
            # and given the sets are unordered, the processing of the set and matching objects would
            # be non-trivial.
            destination[key] = _union_list_of_pydantic_dicts(incoming, existing)
        else:
            destination[key] = incoming
    return destination
if IS_PYDANTIC_V2:

    class V2RootModel(UniversalBaseModel, pydantic.RootModel):  # type: ignore[misc, name-defined, type-arg]
        """Pydantic v2 root model that also carries UniversalBaseModel behavior."""

        pass

    # Single alias so generated code can subclass one name on either Pydantic version.
    UniversalRootModel: TypeAlias = V2RootModel  # type: ignore[misc]
else:
    # On v1, custom-root models are ordinary models with a `__root__` field.
    UniversalRootModel: TypeAlias = UniversalBaseModel  # type: ignore[misc, no-redef]
def encode_by_type(o: Any) -> Any:
    """Encode *o* using the module-level ``encoders_by_type`` registry.

    Tries an exact-type lookup first, then falls back to isinstance checks
    grouped by encoder. Returns None when no encoder applies (None is the
    "not handled" signal used by jsonable_encoder's fallback path).
    """
    # Build the encoder -> (types...) grouping once and memoize it on the
    # function itself: the registry is fixed after import, and the original
    # implementation rebuilt this index on every call.
    encoders_by_class_tuples: Optional[Dict[Callable[[Any], Any], Tuple[Any, ...]]] = getattr(
        encode_by_type, "_encoders_by_class_tuples", None
    )
    if encoders_by_class_tuples is None:
        encoders_by_class_tuples = defaultdict(tuple)
        for type_, encoder in encoders_by_type.items():
            encoders_by_class_tuples[encoder] += (type_,)
        encode_by_type._encoders_by_class_tuples = encoders_by_class_tuples  # type: ignore[attr-defined]
    if type(o) in encoders_by_type:
        return encoders_by_type[type(o)](o)
    for encoder, classes_tuple in encoders_by_class_tuples.items():
        if isinstance(o, classes_tuple):
            return encoder(o)
    return None
def update_forward_refs(model: Type["Model"], **localns: Any) -> None:
    """Resolve outstanding forward references on *model* under either Pydantic major version."""
    if not IS_PYDANTIC_V2:
        model.update_forward_refs(**localns)
        return
    # v2 renamed the hook; errors are suppressed so partially-defined graphs rebuild later.
    model.model_rebuild(raise_errors=False)  # type: ignore[attr-defined]
# Mirrors Pydantic's internal typing
AnyCallable = Callable[..., Any]


def universal_root_validator(
    pre: bool = False,
) -> Callable[[AnyCallable], AnyCallable]:
    """Version-agnostic root/model validator decorator factory.

    ``pre`` is honored only on Pydantic v1; on v2 the validator always runs
    in "before" mode (see the comment inside).
    """

    def decorator(func: AnyCallable) -> AnyCallable:
        if IS_PYDANTIC_V2:
            # In Pydantic v2, for RootModel we always use "before" mode
            # The custom validators transform the input value before the model is created
            return cast(AnyCallable, pydantic.model_validator(mode="before")(func))  # type: ignore[attr-defined]
        return cast(AnyCallable, pydantic.root_validator(pre=pre)(func))  # type: ignore[call-overload]

    return decorator
def universal_field_validator(field_name: str, pre: bool = False) -> Callable[[AnyCallable], AnyCallable]:
    """Version-agnostic per-field validator decorator factory for ``field_name``.

    ``pre=True`` maps to v2 "before" mode / v1 ``pre=True``.
    """

    def decorator(func: AnyCallable) -> AnyCallable:
        if IS_PYDANTIC_V2:
            return cast(AnyCallable, pydantic.field_validator(field_name, mode="before" if pre else "after")(func))  # type: ignore[attr-defined]
        return cast(AnyCallable, pydantic.validator(field_name, pre=pre)(func))

    return decorator
# A field-descriptor object from either Pydantic major version
# (v1 ModelField or v2 FieldInfo).
PydanticField = Union[ModelField, _FieldInfo]


def _get_model_fields(model: Type["Model"]) -> Mapping[str, PydanticField]:
    """Return the model's name -> field-descriptor mapping under either Pydantic major version."""
    if IS_PYDANTIC_V2:
        return cast(Mapping[str, PydanticField], model.model_fields)  # type: ignore[attr-defined]
    return cast(Mapping[str, PydanticField], model.__fields__)
def _get_field_default(field: PydanticField) -> Any:
    """Return a field's default value, or None when the field has no usable default.

    On Pydantic v2 the PydanticUndefined sentinel is normalized to None so
    callers can treat "no default" and "default None" uniformly.
    """
    try:
        value = field.get_default()  # type: ignore[union-attr]
    except Exception:
        # The original bare `except:` also swallowed KeyboardInterrupt/SystemExit.
        # Pydantic v1 ModelField lacks get_default(); fall back to the attribute.
        value = field.default
    if IS_PYDANTIC_V2:
        from pydantic_core import PydanticUndefined

        # Identity check: PydanticUndefined is a singleton sentinel, and `==`
        # could be hijacked by a value with a pathological __eq__.
        if value is PydanticUndefined:
            return None
    return value
================================================
FILE: src/cohere/core/query_encoder.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import Any, Dict, List, Optional, Tuple
import pydantic
# Flattens dicts to be of the form {"key[subkey][subkey2]": value} where value is not a dict
def traverse_query_dict(dict_flat: Dict[str, Any], key_prefix: Optional[str] = None) -> List[Tuple[str, Any]]:
    """Flatten a nested dict into ``[("key[sub][sub2]", leaf_value), ...]`` pairs.

    List values expand element-by-element under the same composed key; dicts
    nested inside lists recurse without adding an index segment.
    """
    pairs: List[Tuple[str, Any]] = []
    for field, value in dict_flat.items():
        composed_key = field if key_prefix is None else f"{key_prefix}[{field}]"
        if isinstance(value, dict):
            pairs.extend(traverse_query_dict(value, composed_key))
        elif isinstance(value, list):
            for element in value:
                if isinstance(element, dict):
                    pairs.extend(traverse_query_dict(element, composed_key))
                else:
                    pairs.append((composed_key, element))
        else:
            pairs.append((composed_key, value))
    return pairs
def single_query_encoder(query_key: str, query_value: Any) -> List[Tuple[str, Any]]:
    """Encode one query parameter into flattened ``(key, value)`` pairs.

    Pydantic models are dumped by alias and dicts are flattened via
    traverse_query_dict; lists expand element-by-element under the same key;
    scalars pass through as a single pair.
    """
    if isinstance(query_value, pydantic.BaseModel):
        return traverse_query_dict(query_value.dict(by_alias=True), query_key)
    if isinstance(query_value, dict):
        return traverse_query_dict(query_value, query_key)
    if not isinstance(query_value, list):
        return [(query_key, query_value)]
    flattened: List[Tuple[str, Any]] = []
    for element in query_value:
        if isinstance(element, pydantic.BaseModel):
            flattened.extend(single_query_encoder(query_key, element.dict(by_alias=True)))
        elif isinstance(element, dict):
            flattened.extend(single_query_encoder(query_key, element))
        else:
            flattened.append((query_key, element))
    return flattened
def encode_query(query: Optional[Dict[str, Any]]) -> Optional[List[Tuple[str, Any]]]:
    """Flatten a query-parameter dict into ``(key, value)`` pairs; ``None`` passes through unchanged."""
    if query is None:
        return None
    return [pair for name, value in query.items() for pair in single_query_encoder(name, value)]
================================================
FILE: src/cohere/core/remove_none_from_dict.py
================================================
# This file was auto-generated by Fern from our API Definition.
from typing import Any, Dict, Mapping, Optional
def remove_none_from_dict(original: Mapping[str, Optional[Any]]) -> Dict[str, Any]:
    """Return a new dict with every None-valued entry of *original* dropped.

    Falsy-but-not-None values (0, "", [], False) are preserved.
    """
    return {key: value for key, value in original.items() if value is not None}
================================================
FILE: src/cohere/core/request_options.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
try:
from typing import NotRequired # type: ignore
except ImportError:
from typing_extensions import NotRequired
class RequestOptions(typing.TypedDict, total=False):
    """
    Additional options for request-specific configuration when calling APIs via the SDK.
    This is used primarily as an optional final parameter for service functions.

    Attributes:
        - timeout_in_seconds: int. The number of seconds to await an API call before timing out.

        - max_retries: int. The max number of retries to attempt if the API call fails.

        - additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict

        - additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict

        - additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict

        - chunk_size: int. The size, in bytes, to process each chunk of data being streamed back within the response. This equates to leveraging `chunk_size` within `requests` or `httpx`, and is only leveraged for file downloads.
    """

    # All keys are optional (total=False); absent keys mean "use client defaults".
    timeout_in_seconds: NotRequired[int]
    max_retries: NotRequired[int]
    additional_headers: NotRequired[typing.Dict[str, typing.Any]]
    additional_query_parameters: NotRequired[typing.Dict[str, typing.Any]]
    additional_body_parameters: NotRequired[typing.Dict[str, typing.Any]]
    chunk_size: NotRequired[int]
================================================
FILE: src/cohere/core/serialization.py
================================================
# This file was auto-generated by Fern from our API Definition.
import collections
import inspect
import typing
import pydantic
import typing_extensions
class FieldMetadata:
    """Annotation metadata that attaches a wire-format alias to a field.

    Example:
        class MyDict(TypedDict):
            field: typing.Annotated[str, FieldMetadata(alias="field_name")]

    causes `{"field": "value"}` to serialize as `{"field_name": "value"}`.
    """

    alias: str

    def __init__(self, *, alias: str) -> None:
        self.alias = alias
def convert_and_respect_annotation_metadata(
    *,
    object_: typing.Any,
    annotation: typing.Any,
    inner_type: typing.Optional[typing.Any] = None,
    direction: typing.Literal["read", "write"],
) -> typing.Any:
    """
    Respect the metadata annotations on a field, such as aliasing. This function effectively
    manipulates the dict-form of an object to respect the metadata annotations. This is primarily used for
    TypedDicts, which cannot support aliasing out of the box, and can be extended for additional
    utilities, such as defaults.

    Parameters
    ----------
    object_ : typing.Any
        The value to convert (mapping, list, set, or scalar).
    annotation : type
        The type we're looking to apply typing annotations from
    inner_type : typing.Optional[type]
        The element type being processed at the current recursion depth;
        defaults to *annotation* at the top level.
    direction : typing.Literal["read", "write"]
        "read" de-aliases incoming keys to field names; "write" applies aliases
        to outgoing keys.

    Returns
    -------
    typing.Any
    """
    if object_ is None:
        return None
    if inner_type is None:
        inner_type = annotation

    clean_type = _remove_annotations(inner_type)
    # Pydantic models
    if (
        inspect.isclass(clean_type)
        and issubclass(clean_type, pydantic.BaseModel)
        and isinstance(object_, typing.Mapping)
    ):
        return _convert_mapping(object_, clean_type, direction)
    # TypedDicts
    if typing_extensions.is_typeddict(clean_type) and isinstance(object_, typing.Mapping):
        return _convert_mapping(object_, clean_type, direction)

    if (
        typing_extensions.get_origin(clean_type) == typing.Dict
        or typing_extensions.get_origin(clean_type) == dict
        or clean_type == typing.Dict
    ) and isinstance(object_, typing.Dict):
        key_type = typing_extensions.get_args(clean_type)[0]
        value_type = typing_extensions.get_args(clean_type)[1]

        return {
            key: convert_and_respect_annotation_metadata(
                object_=value,
                annotation=annotation,
                inner_type=value_type,
                direction=direction,
            )
            for key, value in object_.items()
        }

    # If you're iterating on a string, do not bother to coerce it to a sequence.
    if not isinstance(object_, str):
        if (
            typing_extensions.get_origin(clean_type) == typing.Set
            or typing_extensions.get_origin(clean_type) == set
            or clean_type == typing.Set
        ) and isinstance(object_, typing.Set):
            inner_type = typing_extensions.get_args(clean_type)[0]
            return {
                convert_and_respect_annotation_metadata(
                    object_=item,
                    annotation=annotation,
                    inner_type=inner_type,
                    direction=direction,
                )
                for item in object_
            }
        elif (
            (
                typing_extensions.get_origin(clean_type) == typing.List
                or typing_extensions.get_origin(clean_type) == list
                or clean_type == typing.List
            )
            and isinstance(object_, typing.List)
        ) or (
            (
                typing_extensions.get_origin(clean_type) == typing.Sequence
                or typing_extensions.get_origin(clean_type) == collections.abc.Sequence
                or clean_type == typing.Sequence
            )
            and isinstance(object_, typing.Sequence)
        ):
            inner_type = typing_extensions.get_args(clean_type)[0]
            return [
                convert_and_respect_annotation_metadata(
                    object_=item,
                    annotation=annotation,
                    inner_type=inner_type,
                    direction=direction,
                )
                for item in object_
            ]

    if typing_extensions.get_origin(clean_type) == typing.Union:
        # We should be able to ~relatively~ safely try to convert keys against all
        # member types in the union, the edge case here is if one member aliases a field
        # of the same name to a different name from another member
        # Or if another member aliases a field of the same name that another member does not.
        for member in typing_extensions.get_args(clean_type):
            object_ = convert_and_respect_annotation_metadata(
                object_=object_,
                annotation=annotation,
                inner_type=member,
                direction=direction,
            )
        return object_

    # Scalars and anything not handled above pass through unchanged.
    # (A previous revision computed `_get_annotation(annotation)` here, but both
    # resulting branches returned `object_` unchanged, so the pure call was dead
    # code and has been removed.)
    return object_
def _convert_mapping(
    object_: typing.Mapping[str, object],
    expected_type: typing.Any,
    direction: typing.Literal["read", "write"],
) -> typing.Mapping[str, object]:
    """
    Convert a single mapping's keys and values according to *expected_type*'s
    annotation metadata (field aliases), recursing into each value via
    convert_and_respect_annotation_metadata.
    """
    converted_object: typing.Dict[str, object] = {}
    try:
        annotations = typing_extensions.get_type_hints(expected_type, include_extras=True)
    except NameError:
        # The TypedDict contains a circular reference, so
        # we use the __annotations__ attribute directly.
        annotations = getattr(expected_type, "__annotations__", {})
    aliases_to_field_names = _get_alias_to_field_name(annotations)
    for key, value in object_.items():
        if direction == "read" and key in aliases_to_field_names:
            # Incoming key is an alias: look up its annotation under the Python
            # field name. `get` can only return None if the aliases map were
            # inconsistent; the membership check above makes this safe.
            dealiased_key = aliases_to_field_names.get(key)
            if dealiased_key is not None:
                type_ = annotations.get(dealiased_key)
        else:
            type_ = annotations.get(key)
        # Note you can't get the annotation by the field name if you're in read mode, so you must check the aliases map
        #
        # So this is effectively saying if we're in write mode, and we don't have a type, or if we're in read mode and we don't have an alias
        # then we can just pass the value through as is
        if type_ is None:
            converted_object[key] = value
        elif direction == "read" and key not in aliases_to_field_names:
            # Known field, no alias involved: recurse into the value but keep the key.
            converted_object[key] = convert_and_respect_annotation_metadata(
                object_=value, annotation=type_, direction=direction
            )
        else:
            # Alias translation applies: rename the key for the given direction.
            converted_object[_alias_key(key, type_, direction, aliases_to_field_names)] = (
                convert_and_respect_annotation_metadata(object_=value, annotation=type_, direction=direction)
            )
    return converted_object
def _get_annotation(type_: typing.Any) -> typing.Optional[typing.Any]:
    """Return *type_* itself when it is an ``Annotated[...]`` type, possibly
    wrapped in a single ``NotRequired[...]``; otherwise return ``None``."""
    origin = typing_extensions.get_origin(type_)
    if origin is None:
        return None
    # Peel one NotRequired wrapper, e.g. NotRequired[Annotated[str, ...]].
    if origin == typing_extensions.NotRequired:
        type_ = typing_extensions.get_args(type_)[0]
        origin = typing_extensions.get_origin(type_)
    return type_ if origin == typing_extensions.Annotated else None
def _remove_annotations(type_: typing.Any) -> typing.Any:
    """Strip ``NotRequired[...]`` and ``Annotated[...]`` wrappers (however
    deeply nested) and return the underlying bare type."""
    stripped = type_
    while True:
        origin = typing_extensions.get_origin(stripped)
        if origin not in (typing_extensions.NotRequired, typing_extensions.Annotated):
            return stripped
        stripped = typing_extensions.get_args(stripped)[0]
def get_alias_to_field_mapping(type_: typing.Any) -> typing.Dict[str, str]:
    """Map every serialized alias of *type_* to its Python field name."""
    return _get_alias_to_field_name(typing_extensions.get_type_hints(type_, include_extras=True))
def get_field_to_alias_mapping(type_: typing.Any) -> typing.Dict[str, str]:
    """Map every Python field name of *type_* to its serialized alias."""
    return _get_field_to_alias_name(typing_extensions.get_type_hints(type_, include_extras=True))
def _get_alias_to_field_name(
    field_to_hint: typing.Dict[str, typing.Any],
) -> typing.Dict[str, str]:
    """Build an alias -> field-name map from field hints; fields carrying no
    FieldMetadata alias are omitted."""
    mapping: typing.Dict[str, str] = {}
    for field_name, hint in field_to_hint.items():
        alias = _get_alias_from_type(hint)
        if alias is not None:
            mapping[alias] = field_name
    return mapping
def _get_field_to_alias_name(
    field_to_hint: typing.Dict[str, typing.Any],
) -> typing.Dict[str, str]:
    """Build a field-name -> alias map from field hints; fields carrying no
    FieldMetadata alias are omitted."""
    mapping: typing.Dict[str, str] = {}
    for field_name, hint in field_to_hint.items():
        alias = _get_alias_from_type(hint)
        if alias is not None:
            mapping[field_name] = alias
    return mapping
def _get_alias_from_type(type_: typing.Any) -> typing.Optional[str]:
    """Return the alias declared via FieldMetadata on *type_*, or None."""
    annotated = _get_annotation(type_)
    if annotated is None:
        return None
    # get_args(...)[0] is the wrapped type; the metadata objects follow it.
    for metadata in typing_extensions.get_args(annotated)[1:]:
        if isinstance(metadata, FieldMetadata) and metadata.alias is not None:
            return metadata.alias
    return None
def _alias_key(
    key: str,
    type_: typing.Any,
    direction: typing.Literal["read", "write"],
    aliases_to_field_names: typing.Dict[str, str],
) -> str:
    """Translate *key* between alias and field name for the given direction."""
    if direction == "read":
        # Reading: de-alias the wire key back to the Python field name.
        return aliases_to_field_names.get(key, key)
    # Writing: apply the declared alias when one exists (falsy alias -> keep key).
    alias = _get_alias_from_type(type_=type_)
    return alias or key
================================================
FILE: src/cohere/core/unchecked_base_model.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import enum
import inspect
import sys
import typing
import uuid
import pydantic
import typing_extensions
from .pydantic_utilities import ( # type: ignore[attr-defined]
IS_PYDANTIC_V2,
ModelField,
UniversalBaseModel,
get_args,
get_origin,
is_literal_type,
is_union,
parse_date,
parse_datetime,
parse_obj_as,
)
from .serialization import get_field_to_alias_mapping
from pydantic_core import PydanticUndefined
class UnionMetadata:
    """Annotation marker naming the discriminant field used to pick the
    matching member of a discriminated union."""

    # Name of the field whose value selects the union member.
    discriminant: str

    def __init__(self, *, discriminant: str) -> None:
        self.discriminant = discriminant
Model = typing.TypeVar("Model", bound=pydantic.BaseModel)
def _maybe_resolve_forward_ref(
    type_: typing.Any,
    host: typing.Optional[typing.Type[typing.Any]],
) -> typing.Any:
    """Best-effort resolution of a ForwardRef against *host*'s defining module.

    Pydantic v2 + ``from __future__ import annotations`` can leave field
    annotations as ``list[ForwardRef('Block')]`` even after ``model_rebuild``.
    Without resolution, ``construct_type`` sees a ForwardRef (not a class) and
    skips recursive model construction, leaving nested data as raw dicts.
    Returns *type_* unchanged whenever resolution is impossible or fails.
    """
    if host is None or not isinstance(type_, typing.ForwardRef):
        return type_
    defining_module = sys.modules.get(host.__module__)
    if defining_module is not None:
        try:
            # Evaluate the ref's name in the namespace of the host's module.
            return eval(type_.__forward_arg__, vars(defining_module))
        except Exception:
            pass
    return type_
class UncheckedBaseModel(UniversalBaseModel):
    """Base model that allows extra fields and supports validation-free
    construction via ``construct`` (used when parsing API responses leniently)."""

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            extra = pydantic.Extra.allow

    @classmethod
    def model_construct(
        cls: typing.Type["Model"],
        _fields_set: typing.Optional[typing.Set[str]] = None,
        **values: typing.Any,
    ) -> "Model":
        # Fallback construct function to the specified override below.
        return cls.construct(_fields_set=_fields_set, **values)

    # Allow construct to not validate model
    # Implementation taken from: https://github.com/pydantic/pydantic/issues/1168#issuecomment-817742836
    @classmethod
    def construct(
        cls: typing.Type["Model"],
        _fields_set: typing.Optional[typing.Set[str]] = None,
        **values: typing.Any,
    ) -> "Model":
        """Build an instance without running validation, recursively coercing
        nested values via ``construct_type`` and preserving unknown keys as extras."""
        # Bypass __init__ entirely; field values are assigned directly below.
        m = cls.__new__(cls)
        fields_values = {}
        if _fields_set is None:
            _fields_set = set(values.keys())
        fields = _get_model_fields(cls)
        populate_by_name = _get_is_populate_by_name(cls)
        field_aliases = get_field_to_alias_mapping(cls)
        for name, field in fields.items():
            # Key here is only used to pull data from the values dict
            # you should always use the NAME of the field to for field_values, etc.
            # because that's how the object is constructed from a pydantic perspective
            key = field.alias
            # Prefer the FieldMetadata alias when pydantic has none (or an identity alias).
            if (key is None or field.alias == name) and name in field_aliases:
                key = field_aliases[name]
            if key is None or (key not in values and populate_by_name):  # Added this to allow population by field name
                key = name

            if key in values:
                if IS_PYDANTIC_V2:
                    type_ = field.annotation  # type: ignore # Pydantic v2
                else:
                    type_ = typing.cast(typing.Type, field.outer_type_)  # type: ignore # Pydantic < v1.10.15
                fields_values[name] = (
                    construct_type(object_=values[key], type_=type_, host=cls) if type_ is not None else values[key]
                )
                _fields_set.add(name)
            else:
                default = _get_field_default(field)
                fields_values[name] = default

                # If the default values are non-null act like they've been set
                # This effectively allows exclude_unset to work like exclude_none where
                # the latter passes through intentionally set none values.
                if default != None and default != PydanticUndefined:
                    _fields_set.add(name)

        # Add extras back in
        extras = {}
        pydantic_alias_fields = [field.alias for field in fields.values()]
        internal_alias_fields = list(field_aliases.values())
        for key, value in values.items():
            # If the key is not a field by name, nor an alias to a field, then it's extra
            if (key not in pydantic_alias_fields and key not in internal_alias_fields) and key not in fields:
                if IS_PYDANTIC_V2:
                    extras[key] = value
                else:
                    _fields_set.add(key)
                    fields_values[key] = value

        # object.__setattr__ avoids pydantic's own __setattr__ hooks/validation.
        object.__setattr__(m, "__dict__", fields_values)

        if IS_PYDANTIC_V2:
            object.__setattr__(m, "__pydantic_private__", None)
            object.__setattr__(m, "__pydantic_extra__", extras)
            object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
        else:
            object.__setattr__(m, "__fields_set__", _fields_set)
            m._init_private_attributes()  # type: ignore # Pydantic v1
        return m
def _validate_collection_items_compatible(collection: typing.Any, target_type: typing.Type[typing.Any]) -> bool:
    """
    Check whether every item in *collection* can be treated as *target_type*.

    Args:
        collection: The collection to validate (list, set, or dict values)
        target_type: The target type to validate against

    Returns:
        True if all items are compatible, False otherwise. Non-Pydantic target
        types are vacuously compatible.
    """
    if not (inspect.isclass(target_type) and issubclass(target_type, pydantic.BaseModel)):
        return True
    for candidate in collection:
        try:
            if isinstance(candidate, dict):
                # A dict item must parse cleanly against the model.
                parse_obj_as(target_type, candidate)
            elif not isinstance(candidate, target_type):
                # Non-dict items must already be instances of the model.
                return False
        except Exception:
            return False
    return True
def _get_literal_field_value(
    inner_type: typing.Type[typing.Any], field_name: str, field: typing.Any, object_: typing.Any
) -> typing.Any:
    """Fetch *field_name*'s value from *object_* (dict or attribute object),
    checking the internal alias first and the pydantic alias second."""
    name_or_alias = get_field_to_alias_mapping(inner_type).get(field_name, field_name)
    pydantic_alias = getattr(field, "alias", None)
    if isinstance(object_, dict):
        if name_or_alias in object_:
            return object_[name_or_alias]
        if pydantic_alias and pydantic_alias != name_or_alias and pydantic_alias in object_:
            return object_[pydantic_alias]
        return None
    # Attribute access path: the pydantic-alias lookup is the fallback default.
    fallback = getattr(object_, pydantic_alias, None) if pydantic_alias else None
    return getattr(object_, name_or_alias, fallback)
def _literal_fields_match_strict(inner_type: typing.Type[typing.Any], object_: typing.Any) -> bool:
    """Return True iff every Literal-typed field in *inner_type* is **present** in
    *object_* and its value equals the field's declared default.

    This prevents models whose fields are all optional (e.g. ``FigureDetails``)
    from vacuously matching inputs that don't carry the discriminant key at all
    (e.g. ``{}`` for text blocks). For types with no Literal fields this
    returns True unconditionally.
    """
    for field_name, field in _get_model_fields(inner_type).items():
        annotation = field.annotation if IS_PYDANTIC_V2 else field.outer_type_  # type: ignore
        if not is_literal_type(annotation):  # type: ignore[arg-type]
            continue
        expected = _get_field_default(field)
        actual = _get_literal_field_value(inner_type, field_name, field, object_)
        # A missing key yields None here, which fails the equality check.
        if expected != actual:
            return False
    return True
def _convert_undiscriminated_union_type(
    union_type: typing.Type[typing.Any],
    object_: typing.Any,
    host: typing.Optional[typing.Type[typing.Any]] = None,
) -> typing.Any:
    """Coerce *object_* to a member of *union_type* when no explicit
    discriminant metadata is available, trying progressively looser matching.
    Returns None implicitly if no member accepts the value."""
    inner_types = get_args(union_type)
    if typing.Any in inner_types:
        return object_

    # When any union member carries a Literal discriminant field, require the
    # discriminant key to be present AND matching before accepting a candidate.
    # This prevents models with all-optional fields (e.g. FigureDetails) from
    # greedily matching inputs that belong to a different variant or to a
    # plain-dict fallback (e.g. EmptyBlockDetails = Dict[str, Any]).
    has_literal_discriminant = any(
        inspect.isclass(t)
        and issubclass(t, pydantic.BaseModel)
        and any(
            is_literal_type(
                f.annotation if IS_PYDANTIC_V2 else f.outer_type_  # type: ignore
            )
            for f in _get_model_fields(t).values()
        )
        for t in inner_types
    )

    # Initial pass: fully-validated parses (strongest evidence of a match).
    for inner_type in inner_types:
        # Handle lists of objects that need parsing
        if get_origin(inner_type) is list and isinstance(object_, list):
            list_inner_type = _maybe_resolve_forward_ref(get_args(inner_type)[0], host)
            try:
                if inspect.isclass(list_inner_type) and issubclass(list_inner_type, pydantic.BaseModel):
                    # Validate that all items in the list are compatible with the target type
                    if _validate_collection_items_compatible(object_, list_inner_type):
                        parsed_list = [parse_obj_as(object_=item, type_=list_inner_type) for item in object_]
                        return parsed_list
            except Exception:
                pass
        try:
            if inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel):
                if has_literal_discriminant and not _literal_fields_match_strict(inner_type, object_):
                    continue
                # Attempt a validated parse until one works
                return parse_obj_as(inner_type, object_)
        except Exception:
            continue

    # First pass: try types where all literal fields match the object's values.
    for inner_type in inner_types:
        if inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel):
            if has_literal_discriminant:
                if not _literal_fields_match_strict(inner_type, object_):
                    continue
            else:
                # Legacy lenient check: skip only when a Literal value is
                # present but doesn't match (allows absent-discriminant inputs).
                fields = _get_model_fields(inner_type)
                literal_fields_match = True
                for field_name, field in fields.items():
                    if IS_PYDANTIC_V2:
                        field_type = field.annotation  # type: ignore # Pydantic v2
                    else:
                        field_type = field.outer_type_  # type: ignore # Pydantic v1
                    if is_literal_type(field_type):  # type: ignore[arg-type]
                        field_default = _get_field_default(field)
                        object_value = _get_literal_field_value(inner_type, field_name, field, object_)
                        if object_value is not None and field_default != object_value:
                            literal_fields_match = False
                            break
                if not literal_fields_match:
                    continue
            try:
                return construct_type(object_=object_, type_=inner_type, host=host)
            except Exception:
                continue

    # Second pass: if no literal matches, return the first successful cast.
    # When a Literal discriminant is present, skip Pydantic models whose
    # discriminant doesn't match so that plain-dict fallback types are reached.
    for inner_type in inner_types:
        try:
            if has_literal_discriminant and inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel):
                if not _literal_fields_match_strict(inner_type, object_):
                    continue
            return construct_type(object_=object_, type_=inner_type, host=host)
        except Exception:
            continue
def _convert_union_type(
    type_: typing.Type[typing.Any],
    object_: typing.Any,
    host: typing.Optional[typing.Type[typing.Any]] = None,
) -> typing.Any:
    """Coerce *object_* to one of the members of the union *type_*.

    If the union is wrapped in ``Annotated[..., UnionMetadata(...)]``, the named
    discriminant field is used to select the member directly; otherwise (or on
    any failure) we fall back to undiscriminated member-by-member matching.
    """
    base_type = get_origin(type_) or type_
    union_type = type_
    if base_type == typing_extensions.Annotated:  # type: ignore[comparison-overlap]
        union_type = get_args(type_)[0]
        annotated_metadata = get_args(type_)[1:]
        for metadata in annotated_metadata:
            if isinstance(metadata, UnionMetadata):
                try:
                    # Cast to the correct type, based on the discriminant
                    for inner_type in get_args(union_type):
                        try:
                            objects_discriminant = getattr(object_, metadata.discriminant)
                        except Exception:
                            # Fix: was a bare `except:`, which also swallowed
                            # KeyboardInterrupt/SystemExit. Fall back to item access.
                            objects_discriminant = object_[metadata.discriminant]
                        if inner_type.__fields__[metadata.discriminant].default == objects_discriminant:
                            return construct_type(object_=object_, type_=inner_type, host=host)
                except Exception:
                    # Allow to fall through to our regular union handling
                    pass
    return _convert_undiscriminated_union_type(union_type, object_, host)
def construct_type(
    *,
    type_: typing.Type[typing.Any],
    object_: typing.Any,
    host: typing.Optional[typing.Type[typing.Any]] = None,
) -> typing.Any:
    """
    Here we are essentially creating the same `construct` method in spirit as the above, but for all types, not just
    Pydantic models.
    The idea is to essentially attempt to coerce object_ to type_ (recursively)

    On any coercion failure the original value is returned unchanged rather
    than raising. *host* is the class whose module is used to resolve
    ForwardRef annotations.
    """
    # Short circuit when dealing with optionals, don't try to coerces None to a type
    if object_ is None:
        return None

    base_type = get_origin(type_) or type_
    is_annotated = base_type == typing_extensions.Annotated  # type: ignore[comparison-overlap]
    maybe_annotation_members = get_args(type_)
    # Short-circuiting `and` keeps the [0] index safe when there are no args.
    is_annotated_union = is_annotated and is_union(get_origin(maybe_annotation_members[0]))

    if base_type == typing.Any:  # type: ignore[comparison-overlap]
        return object_

    if base_type == dict:
        if not isinstance(object_, typing.Mapping):
            return object_

        key_type, items_type = get_args(type_)
        key_type = _maybe_resolve_forward_ref(key_type, host)
        items_type = _maybe_resolve_forward_ref(items_type, host)
        d = {
            construct_type(object_=key, type_=key_type, host=host): construct_type(
                object_=item, type_=items_type, host=host
            )
            for key, item in object_.items()
        }
        return d

    if base_type == list:
        if not isinstance(object_, list):
            return object_

        inner_type = _maybe_resolve_forward_ref(get_args(type_)[0], host)
        return [construct_type(object_=entry, type_=inner_type, host=host) for entry in object_]

    if base_type == set:
        # JSON has no set type, so a list payload is also accepted here.
        if not isinstance(object_, set) and not isinstance(object_, list):
            return object_

        inner_type = _maybe_resolve_forward_ref(get_args(type_)[0], host)
        return {construct_type(object_=entry, type_=inner_type, host=host) for entry in object_}

    if is_union(base_type) or is_annotated_union:
        return _convert_union_type(type_, object_, host)

    # Cannot do an `issubclass` with a literal type, let's also just confirm we have a class before this call
    if (
        object_ is not None
        and not is_literal_type(type_)
        and (
            (inspect.isclass(base_type) and issubclass(base_type, pydantic.BaseModel))
            or (
                is_annotated
                and inspect.isclass(maybe_annotation_members[0])
                and issubclass(maybe_annotation_members[0], pydantic.BaseModel)
            )
        )
    ):
        if IS_PYDANTIC_V2:
            return type_.model_construct(**object_)
        else:
            return type_.construct(**object_)

    if base_type == dt.datetime:
        try:
            return parse_datetime(object_)
        except Exception:
            return object_

    if base_type == dt.date:
        try:
            return parse_date(object_)
        except Exception:
            return object_

    if base_type == uuid.UUID:
        try:
            return uuid.UUID(object_)
        except Exception:
            return object_

    if base_type == int:
        try:
            return int(object_)
        except Exception:
            return object_

    if base_type == bool:
        try:
            # Strings are interpreted leniently: "true"/"1" (any case) -> True.
            if isinstance(object_, str):
                stringified_object = object_.lower()
                return stringified_object == "true" or stringified_object == "1"

            return bool(object_)
        except Exception:
            return object_

    if inspect.isclass(base_type) and issubclass(base_type, enum.Enum):
        try:
            return base_type(object_)
        except (ValueError, KeyError):
            return object_

    return object_
def _get_is_populate_by_name(model: typing.Type["Model"]) -> bool:
    """Whether *model* permits population by field name (not just by alias)."""
    if not IS_PYDANTIC_V2:
        return model.__config__.allow_population_by_field_name  # type: ignore # Pydantic v1
    return model.model_config.get("populate_by_name", False)  # type: ignore # Pydantic v2
from pydantic.fields import FieldInfo as _FieldInfo
PydanticField = typing.Union[ModelField, _FieldInfo]
# Pydantic V1 swapped the typing of __fields__'s values from ModelField to FieldInfo
# And so we try to handle both V1 cases, as well as V2 (FieldInfo from model.model_fields)
def _get_model_fields(
    model: typing.Type["Model"],
) -> typing.Mapping[str, PydanticField]:
    """Version-agnostic accessor for a pydantic model's field mapping."""
    # Pydantic v2 exposes model_fields; v1 uses __fields__.
    return model.model_fields if IS_PYDANTIC_V2 else model.__fields__  # type: ignore
def _get_field_default(field: PydanticField) -> typing.Any:
    """Return *field*'s default value, normalizing Pydantic v2's
    ``PydanticUndefined`` sentinel (no default declared) to ``None``.

    Fixes vs. previous revision: the bare ``except:`` (which also trapped
    KeyboardInterrupt/SystemExit) is narrowed to ``except Exception:``; the
    redundant local ``from pydantic_core import PydanticUndefined`` is removed
    (the module already imports it at the top); and the sentinel is compared
    by identity, avoiding any custom ``__eq__`` on the default value.
    """
    try:
        value = field.get_default()  # type: ignore # Pydantic < v1.10.15
    except Exception:
        # Older pydantic v1 field objects lack get_default(); fall back to .default.
        value = field.default
    if IS_PYDANTIC_V2 and value is PydanticUndefined:
        return None
    return value
================================================
FILE: src/cohere/datasets/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .types import DatasetsCreateResponse, DatasetsGetResponse, DatasetsGetUsageResponse, DatasetsListResponse
_dynamic_imports: typing.Dict[str, str] = {
"DatasetsCreateResponse": ".types",
"DatasetsGetResponse": ".types",
"DatasetsGetUsageResponse": ".types",
"DatasetsListResponse": ".types",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve *attr_name* through the `_dynamic_imports` table (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # A "Name" -> ".Name" entry means the attribute *is* the submodule itself.
        return module if module_name == f".{attr_name}" else getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily-importable names to dir()/autocomplete."""
    # Iterating the dict yields its keys; sorted() returns them as a list.
    return sorted(_dynamic_imports)
__all__ = ["DatasetsCreateResponse", "DatasetsGetResponse", "DatasetsGetUsageResponse", "DatasetsListResponse"]
================================================
FILE: src/cohere/datasets/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
from .. import core
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from ..types.dataset_type import DatasetType
from ..types.dataset_validation_status import DatasetValidationStatus
from .raw_client import AsyncRawDatasetsClient, RawDatasetsClient
from .types.datasets_create_response import DatasetsCreateResponse
from .types.datasets_get_response import DatasetsGetResponse
from .types.datasets_get_usage_response import DatasetsGetUsageResponse
from .types.datasets_list_response import DatasetsListResponse
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)
class DatasetsClient:
    """Synchronous client for the Datasets API.

    Each method delegates to :class:`RawDatasetsClient` and unwraps the raw
    HTTP response, returning only its parsed ``.data`` payload.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawDatasetsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawDatasetsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawDatasetsClient
        """
        return self._raw_client

    def list(
        self,
        *,
        dataset_type: typing.Optional[str] = None,
        before: typing.Optional[dt.datetime] = None,
        after: typing.Optional[dt.datetime] = None,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        validation_status: typing.Optional[DatasetValidationStatus] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsListResponse:
        """
        List datasets that have been created.

        Parameters
        ----------
        dataset_type : typing.Optional[str]
            optional filter by dataset type

        before : typing.Optional[dt.datetime]
            optional filter before a date

        after : typing.Optional[dt.datetime]
            optional filter after a date

        limit : typing.Optional[float]
            optional limit to number of results

        offset : typing.Optional[float]
            optional offset to start of results

        validation_status : typing.Optional[DatasetValidationStatus]
            optional filter by validation status

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsListResponse
            A successful response.

        Examples
        --------
        import datetime

        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.datasets.list(
            dataset_type="datasetType",
            before=datetime.datetime.fromisoformat(
                "2024-01-15 09:30:00+00:00",
            ),
            after=datetime.datetime.fromisoformat(
                "2024-01-15 09:30:00+00:00",
            ),
            limit=1.1,
            offset=1.1,
            validation_status="unknown",
        )
        """
        _response = self._raw_client.list(
            dataset_type=dataset_type,
            before=before,
            after=after,
            limit=limit,
            offset=offset,
            validation_status=validation_status,
            request_options=request_options,
        )
        return _response.data

    def create(
        self,
        *,
        name: str,
        type: DatasetType,
        data: core.File,
        keep_original_file: typing.Optional[bool] = None,
        skip_malformed_input: typing.Optional[bool] = None,
        keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        text_separator: typing.Optional[str] = None,
        csv_delimiter: typing.Optional[str] = None,
        eval_data: typing.Optional[core.File] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsCreateResponse:
        """
        Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information.

        Parameters
        ----------
        name : str
            The name of the uploaded dataset.

        type : DatasetType
            The dataset type, which is used to validate the data. The only valid type is `embed-input` used in conjunction with the Embed Jobs API.

        data : core.File
            See core.File for more documentation

        keep_original_file : typing.Optional[bool]
            Indicates if the original file should be stored.

        skip_malformed_input : typing.Optional[bool]
            Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field.

        keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `keep_fields` are missing from the uploaded file, Dataset validation will fail.

        optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). For example, Datasets of type `embed-input` will drop all fields other than the required `text` field. If any of the fields in `optional_fields` are missing from the uploaded file, Dataset validation will pass.

        text_separator : typing.Optional[str]
            Raw .txt uploads will be split into entries using the text_separator value.

        csv_delimiter : typing.Optional[str]
            The delimiter used for .csv uploads.

        eval_data : typing.Optional[core.File]
            See core.File for more documentation

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsCreateResponse
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.datasets.create(
            name="name",
            type="embed-input",
            keep_original_file=True,
            skip_malformed_input=True,
            text_separator="text_separator",
            csv_delimiter="csv_delimiter",
        )
        """
        _response = self._raw_client.create(
            name=name,
            type=type,
            data=data,
            keep_original_file=keep_original_file,
            skip_malformed_input=skip_malformed_input,
            keep_fields=keep_fields,
            optional_fields=optional_fields,
            text_separator=text_separator,
            csv_delimiter=csv_delimiter,
            eval_data=eval_data,
            request_options=request_options,
        )
        return _response.data

    def get_usage(self, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetUsageResponse:
        """
        View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetUsageResponse
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.datasets.get_usage()
        """
        _response = self._raw_client.get_usage(request_options=request_options)
        return _response.data

    def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetResponse:
        """
        Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetResponse
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.datasets.get(
            id="id",
        )
        """
        _response = self._raw_client.get(id, request_options=request_options)
        return _response.data

    def delete(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> typing.Dict[str, typing.Any]:
        """
        Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        typing.Dict[str, typing.Any]
            A successful response.

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.datasets.delete(
            id="id",
        )
        """
        _response = self._raw_client.delete(id, request_options=request_options)
        return _response.data
class AsyncDatasetsClient:
    """
    Asynchronous high-level client for the Datasets API.

    Delegates every call to an :class:`AsyncRawDatasetsClient` and returns
    the parsed response body instead of the raw HTTP response.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawDatasetsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawDatasetsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawDatasetsClient
        """
        return self._raw_client

    async def list(
        self,
        *,
        dataset_type: typing.Optional[str] = None,
        before: typing.Optional[dt.datetime] = None,
        after: typing.Optional[dt.datetime] = None,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        validation_status: typing.Optional[DatasetValidationStatus] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsListResponse:
        """
        List datasets that have been created.

        Parameters
        ----------
        dataset_type : typing.Optional[str]
            optional filter by dataset type

        before : typing.Optional[dt.datetime]
            optional filter before a date

        after : typing.Optional[dt.datetime]
            optional filter after a date

        limit : typing.Optional[float]
            optional limit to number of results

        offset : typing.Optional[float]
            optional offset to start of results

        validation_status : typing.Optional[DatasetValidationStatus]
            optional filter by validation status

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsListResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.datasets.list()

        asyncio.run(main())
        """
        raw_response = await self._raw_client.list(
            dataset_type=dataset_type,
            before=before,
            after=after,
            limit=limit,
            offset=offset,
            validation_status=validation_status,
            request_options=request_options,
        )
        return raw_response.data

    async def create(
        self,
        *,
        name: str,
        type: DatasetType,
        data: core.File,
        keep_original_file: typing.Optional[bool] = None,
        skip_malformed_input: typing.Optional[bool] = None,
        keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        text_separator: typing.Optional[str] = None,
        csv_delimiter: typing.Optional[str] = None,
        eval_data: typing.Optional[core.File] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> DatasetsCreateResponse:
        """
        Create a dataset by uploading a file. See
        ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation)
        for more information.

        Parameters
        ----------
        name : str
            The name of the uploaded dataset.

        type : DatasetType
            The dataset type, which is used to validate the data. The only valid type is `embed-input` used in conjunction with the Embed Jobs API.

        data : core.File
            See core.File for more documentation

        keep_original_file : typing.Optional[bool]
            Indicates if the original file should be stored.

        skip_malformed_input : typing.Optional[bool]
            Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field.

        keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). If any of the fields in `keep_fields` are missing from the uploaded file, Dataset validation will fail.

        optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). If any of the fields in `optional_fields` are missing from the uploaded file, Dataset validation will pass.

        text_separator : typing.Optional[str]
            Raw .txt uploads will be split into entries using the text_separator value.

        csv_delimiter : typing.Optional[str]
            The delimiter used for .csv uploads.

        eval_data : typing.Optional[core.File]
            See core.File for more documentation

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsCreateResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.datasets.create(
                name="name",
                type="embed-input",
                data=open("data.jsonl", "rb"),
            )

        asyncio.run(main())
        """
        raw_response = await self._raw_client.create(
            name=name,
            type=type,
            data=data,
            keep_original_file=keep_original_file,
            skip_malformed_input=skip_malformed_input,
            keep_fields=keep_fields,
            optional_fields=optional_fields,
            text_separator=text_separator,
            csv_delimiter=csv_delimiter,
            eval_data=eval_data,
            request_options=request_options,
        )
        return raw_response.data

    async def get_usage(self, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetUsageResponse:
        """
        Report dataset storage usage for your Organization. Each Organization
        can store up to 10GB of datasets across all of its users.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetUsageResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.datasets.get_usage()

        asyncio.run(main())
        """
        return (await self._raw_client.get_usage(request_options=request_options)).data

    async def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> DatasetsGetResponse:
        """
        Fetch a single dataset by its ID. See
        ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DatasetsGetResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.datasets.get(
                id="id",
            )

        asyncio.run(main())
        """
        return (await self._raw_client.get(id, request_options=request_options)).data

    async def delete(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> typing.Dict[str, typing.Any]:
        """
        Remove a dataset by its ID. Datasets expire automatically after 30
        days, but may also be deleted manually via this endpoint.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        typing.Dict[str, typing.Any]
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.datasets.delete(
                id="id",
            )

        asyncio.run(main())
        """
        return (await self._raw_client.delete(id, request_options=request_options)).data
================================================
FILE: src/cohere/datasets/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
from json.decoder import JSONDecodeError
from .. import core
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.datetime_utils import serialize_datetime
from ..core.http_response import AsyncHttpResponse, HttpResponse
from ..core.jsonable_encoder import jsonable_encoder
from ..core.parse_error import ParsingError
from ..core.request_options import RequestOptions
from ..core.unchecked_base_model import construct_type
from ..errors.bad_request_error import BadRequestError
from ..errors.client_closed_request_error import ClientClosedRequestError
from ..errors.forbidden_error import ForbiddenError
from ..errors.gateway_timeout_error import GatewayTimeoutError
from ..errors.internal_server_error import InternalServerError
from ..errors.invalid_token_error import InvalidTokenError
from ..errors.not_found_error import NotFoundError
from ..errors.not_implemented_error import NotImplementedError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.too_many_requests_error import TooManyRequestsError
from ..errors.unauthorized_error import UnauthorizedError
from ..errors.unprocessable_entity_error import UnprocessableEntityError
from ..types.dataset_type import DatasetType
from ..types.dataset_validation_status import DatasetValidationStatus
from .types.datasets_create_response import DatasetsCreateResponse
from .types.datasets_get_response import DatasetsGetResponse
from .types.datasets_get_usage_response import DatasetsGetUsageResponse
from .types.datasets_list_response import DatasetsListResponse
from pydantic import ValidationError
# Sentinel default for optional request parameters: lets the SDK distinguish
# "argument not supplied" from an explicit ``None``.
OMIT = typing.cast(typing.Any, ...)
class RawDatasetsClient:
    """
    Synchronous low-level client for the Datasets API.

    Each method issues the HTTP request and returns an :class:`HttpResponse`
    wrapping both the parsed body and the underlying response; non-2xx
    statuses raise a typed error. The status-code handling is identical for
    every endpoint, so it is factored into :meth:`_process`.
    """

    # Shared mapping of HTTP status code -> typed error class raised for that
    # response. Every listed class takes ``headers=`` and ``body=`` keywords.
    _STATUS_CODE_TO_ERROR: typing.ClassVar[typing.Dict[int, typing.Any]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    def _process(self, _response: typing.Any, success_type: typing.Any) -> typing.Any:
        """
        Turn an HTTP response into ``HttpResponse[success_type]`` or raise.

        Behaviour (shared by all endpoints):
        - 2xx: parse the JSON body as ``success_type`` and wrap it.
        - Known error status: raise the matching typed error with the parsed
          JSON body.
        - Any other status: raise ``ApiError`` with the JSON body.
        - Body is not valid JSON: raise ``ApiError`` with the raw text.
        - Body fails model validation: raise ``ParsingError``.
        """
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    typing.Any,
                    construct_type(
                        type_=success_type,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            _error_class = self._STATUS_CODE_TO_ERROR.get(_response.status_code)
            if _error_class is not None:
                # Raised inside the try so a non-JSON error body still falls
                # through to the JSONDecodeError handler below.
                raise _error_class(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def list(
        self,
        *,
        dataset_type: typing.Optional[str] = None,
        before: typing.Optional[dt.datetime] = None,
        after: typing.Optional[dt.datetime] = None,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        validation_status: typing.Optional[DatasetValidationStatus] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[DatasetsListResponse]:
        """
        List datasets that have been created.

        Parameters
        ----------
        dataset_type : typing.Optional[str]
            optional filter by dataset type

        before : typing.Optional[dt.datetime]
            optional filter before a date

        after : typing.Optional[dt.datetime]
            optional filter after a date

        limit : typing.Optional[float]
            optional limit to number of results

        offset : typing.Optional[float]
            optional offset to start of results

        validation_status : typing.Optional[DatasetValidationStatus]
            optional filter by validation status

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[DatasetsListResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/datasets",
            method="GET",
            params={
                "datasetType": dataset_type,
                "before": serialize_datetime(before) if before is not None else None,
                "after": serialize_datetime(after) if after is not None else None,
                "limit": limit,
                "offset": offset,
                "validationStatus": validation_status,
            },
            request_options=request_options,
        )
        return self._process(_response, DatasetsListResponse)

    def create(
        self,
        *,
        name: str,
        type: DatasetType,
        data: core.File,
        keep_original_file: typing.Optional[bool] = None,
        skip_malformed_input: typing.Optional[bool] = None,
        keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        text_separator: typing.Optional[str] = None,
        csv_delimiter: typing.Optional[str] = None,
        eval_data: typing.Optional[core.File] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[DatasetsCreateResponse]:
        """
        Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information.

        Parameters
        ----------
        name : str
            The name of the uploaded dataset.

        type : DatasetType
            The dataset type, which is used to validate the data. The only valid type is `embed-input` used in conjunction with the Embed Jobs API.

        data : core.File
            See core.File for more documentation

        keep_original_file : typing.Optional[bool]
            Indicates if the original file should be stored.

        skip_malformed_input : typing.Optional[bool]
            Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field.

        keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). If any of the fields in `keep_fields` are missing from the uploaded file, Dataset validation will fail.

        optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            List of names of fields that will be persisted in the Dataset. By default the Dataset will retain only the required fields indicated in the [schema for the corresponding Dataset type](https://docs.cohere.com/docs/datasets#dataset-types). If any of the fields in `optional_fields` are missing from the uploaded file, Dataset validation will pass.

        text_separator : typing.Optional[str]
            Raw .txt uploads will be split into entries using the text_separator value.

        csv_delimiter : typing.Optional[str]
            The delimiter used for .csv uploads.

        eval_data : typing.Optional[core.File]
            See core.File for more documentation

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[DatasetsCreateResponse]
            A successful response.
        """
        # NOTE: the Datasets endpoint takes the scalar fields as query
        # parameters; only the files travel in the multipart body.
        _response = self._client_wrapper.httpx_client.request(
            "v1/datasets",
            method="POST",
            params={
                "name": name,
                "type": type,
                "keep_original_file": keep_original_file,
                "skip_malformed_input": skip_malformed_input,
                "keep_fields": keep_fields,
                "optional_fields": optional_fields,
                "text_separator": text_separator,
                "csv_delimiter": csv_delimiter,
            },
            data={},
            files={
                "data": data,
                **({"eval_data": eval_data} if eval_data is not None else {}),
            },
            request_options=request_options,
            omit=OMIT,
            force_multipart=True,
        )
        return self._process(_response, DatasetsCreateResponse)

    def get_usage(
        self, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[DatasetsGetUsageResponse]:
        """
        View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[DatasetsGetUsageResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/datasets/usage",
            method="GET",
            request_options=request_options,
        )
        return self._process(_response, DatasetsGetUsageResponse)

    def get(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[DatasetsGetResponse]:
        """
        Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[DatasetsGetResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/datasets/{jsonable_encoder(id)}",
            method="GET",
            request_options=request_options,
        )
        return self._process(_response, DatasetsGetResponse)

    def delete(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[typing.Dict[str, typing.Any]]:
        """
        Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[typing.Dict[str, typing.Any]]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/datasets/{jsonable_encoder(id)}",
            method="DELETE",
            request_options=request_options,
        )
        return self._process(_response, typing.Dict[str, typing.Any])
class AsyncRawDatasetsClient:
    """Asynchronous low-level client for the ``v1/datasets`` endpoints.

    Every method returns an ``AsyncHttpResponse`` carrying both the parsed
    payload and the underlying HTTP response. Non-2xx statuses are translated
    into the SDK's typed exceptions by ``_raise_api_error``; all dataset
    endpoints share the same error contract.
    """

    # HTTP status code -> typed exception raised for that status. The mapping
    # is identical for every endpoint in this client, so it lives here once.
    _STATUS_CODE_TO_ERROR: typing.ClassVar[typing.Dict[int, typing.Any]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._client_wrapper = client_wrapper

    def _raise_api_error(self, _response: typing.Any) -> typing.NoReturn:
        """Translate a non-success ``_response`` into the matching typed exception.

        Statuses without a dedicated exception type fall back to a generic
        ``ApiError`` whose body is the decoded JSON payload. A ``JSONDecodeError``
        raised while decoding the body propagates to the caller, which converts
        it into an ``ApiError`` carrying the raw response text.
        """
        error_cls = self._STATUS_CODE_TO_ERROR.get(_response.status_code)
        if error_cls is not None:
            raise error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        raise ApiError(
            status_code=_response.status_code,
            headers=dict(_response.headers),
            body=_response.json(),
        )

    async def list(
        self,
        *,
        dataset_type: typing.Optional[str] = None,
        before: typing.Optional[dt.datetime] = None,
        after: typing.Optional[dt.datetime] = None,
        limit: typing.Optional[float] = None,
        offset: typing.Optional[float] = None,
        validation_status: typing.Optional[DatasetValidationStatus] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[DatasetsListResponse]:
        """
        List datasets that have been created.

        Parameters
        ----------
        dataset_type : typing.Optional[str]
            optional filter by dataset type

        before : typing.Optional[dt.datetime]
            optional filter before a date

        after : typing.Optional[dt.datetime]
            optional filter after a date

        limit : typing.Optional[float]
            optional limit to number of results

        offset : typing.Optional[float]
            optional offset to start of results

        validation_status : typing.Optional[DatasetValidationStatus]
            optional filter by validation status

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[DatasetsListResponse]
            A successful response.
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v1/datasets",
            method="GET",
            params={
                "datasetType": dataset_type,
                "before": serialize_datetime(before) if before is not None else None,
                "after": serialize_datetime(after) if after is not None else None,
                "limit": limit,
                "offset": offset,
                "validationStatus": validation_status,
            },
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                parsed = typing.cast(
                    DatasetsListResponse,
                    construct_type(
                        type_=DatasetsListResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=parsed)
            self._raise_api_error(_response)
        except JSONDecodeError:
            # Body was not valid JSON: surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )

    async def create(
        self,
        *,
        name: str,
        type: DatasetType,
        data: core.File,
        keep_original_file: typing.Optional[bool] = None,
        skip_malformed_input: typing.Optional[bool] = None,
        keep_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        optional_fields: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
        text_separator: typing.Optional[str] = None,
        csv_delimiter: typing.Optional[str] = None,
        eval_data: typing.Optional[core.File] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[DatasetsCreateResponse]:
        """
        Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information.

        Parameters
        ----------
        name : str
            The name of the uploaded dataset.

        type : DatasetType
            The dataset type, used to validate the data. The only valid type is
            `embed-input`, used in conjunction with the Embed Jobs API.

        data : core.File
            See core.File for more documentation

        keep_original_file : typing.Optional[bool]
            Indicates if the original file should be stored.

        skip_malformed_input : typing.Optional[bool]
            Drop rows with malformed input instead of failing validation; dropped
            rows are returned in the warnings field.

        keep_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            Field names to persist in the Dataset beyond the required ones for the
            Dataset type. Validation fails if any listed field is missing from the
            uploaded file.

        optional_fields : typing.Optional[typing.Union[str, typing.Sequence[str]]]
            Field names to persist in the Dataset beyond the required ones for the
            Dataset type. Validation still passes if a listed field is missing
            from the uploaded file.

        text_separator : typing.Optional[str]
            Raw .txt uploads are split into entries using this separator.

        csv_delimiter : typing.Optional[str]
            The delimiter used for .csv uploads.

        eval_data : typing.Optional[core.File]
            See core.File for more documentation

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[DatasetsCreateResponse]
            A successful response.
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v1/datasets",
            method="POST",
            params={
                "name": name,
                "type": type,
                "keep_original_file": keep_original_file,
                "skip_malformed_input": skip_malformed_input,
                "keep_fields": keep_fields,
                "optional_fields": optional_fields,
                "text_separator": text_separator,
                "csv_delimiter": csv_delimiter,
            },
            data={},
            files={
                "data": data,
                **({"eval_data": eval_data} if eval_data is not None else {}),
            },
            request_options=request_options,
            omit=OMIT,
            force_multipart=True,
        )
        try:
            if 200 <= _response.status_code < 300:
                parsed = typing.cast(
                    DatasetsCreateResponse,
                    construct_type(
                        type_=DatasetsCreateResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=parsed)
            self._raise_api_error(_response)
        except JSONDecodeError:
            # Body was not valid JSON: surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )

    async def get_usage(
        self, *, request_options: typing.Optional[RequestOptions] = None
    ) -> AsyncHttpResponse[DatasetsGetUsageResponse]:
        """
        View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[DatasetsGetUsageResponse]
            A successful response.
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v1/datasets/usage",
            method="GET",
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                parsed = typing.cast(
                    DatasetsGetUsageResponse,
                    construct_type(
                        type_=DatasetsGetUsageResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=parsed)
            self._raise_api_error(_response)
        except JSONDecodeError:
            # Body was not valid JSON: surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )

    async def get(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> AsyncHttpResponse[DatasetsGetResponse]:
        """
        Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[DatasetsGetResponse]
            A successful response.
        """
        _response = await self._client_wrapper.httpx_client.request(
            f"v1/datasets/{jsonable_encoder(id)}",
            method="GET",
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                parsed = typing.cast(
                    DatasetsGetResponse,
                    construct_type(
                        type_=DatasetsGetResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=parsed)
            self._raise_api_error(_response)
        except JSONDecodeError:
            # Body was not valid JSON: surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )

    async def delete(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> AsyncHttpResponse[typing.Dict[str, typing.Any]]:
        """
        Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually.

        Parameters
        ----------
        id : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[typing.Dict[str, typing.Any]]
            A successful response.
        """
        _response = await self._client_wrapper.httpx_client.request(
            f"v1/datasets/{jsonable_encoder(id)}",
            method="DELETE",
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                parsed = typing.cast(
                    typing.Dict[str, typing.Any],
                    construct_type(
                        type_=typing.Dict[str, typing.Any],  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return AsyncHttpResponse(response=_response, data=parsed)
            self._raise_api_error(_response)
        except JSONDecodeError:
            # Body was not valid JSON: surface the raw text instead.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
================================================
FILE: src/cohere/datasets/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.

# isort: skip_file

import typing
from importlib import import_module

# These imports are only evaluated by static type checkers; at runtime the
# same names are resolved lazily through the module-level __getattr__ below
# (PEP 562), so importing this package stays cheap.
if typing.TYPE_CHECKING:
    from .datasets_create_response import DatasetsCreateResponse
    from .datasets_get_response import DatasetsGetResponse
    from .datasets_get_usage_response import DatasetsGetUsageResponse
    from .datasets_list_response import DatasetsListResponse

# Public name -> relative module that defines it; consumed by __getattr__ to
# import each type on first attribute access.
_dynamic_imports: typing.Dict[str, str] = {
    "DatasetsCreateResponse": ".datasets_create_response",
    "DatasetsGetResponse": ".datasets_get_response",
    "DatasetsGetUsageResponse": ".datasets_get_usage_response",
    "DatasetsListResponse": ".datasets_list_response",
}
def __getattr__(attr_name: str) -> typing.Any:
module_name = _dynamic_imports.get(attr_name)
if module_name is None:
raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
try:
module = import_module(module_name, __package__)
if module_name == f".{attr_name}":
return module
else:
return getattr(module, attr_name)
except ImportError as e:
raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
except AttributeError as e:
raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily importable names for dir()/autocomplete."""
    return sorted(_dynamic_imports)
__all__ = ["DatasetsCreateResponse", "DatasetsGetResponse", "DatasetsGetUsageResponse", "DatasetsListResponse"]
================================================
FILE: src/cohere/datasets/types/datasets_create_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
class DatasetsCreateResponse(UncheckedBaseModel):
    """Response payload of the dataset-create endpoint."""

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The dataset ID
    """

    if IS_PYDANTIC_V2:
        # Allow unknown fields so new server-side keys don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/datasets/types/datasets_get_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.dataset import Dataset
class DatasetsGetResponse(UncheckedBaseModel):
    """Response payload of the dataset-get endpoint, wrapping a single Dataset."""

    # The requested dataset record.
    dataset: Dataset

    if IS_PYDANTIC_V2:
        # Allow unknown fields so new server-side keys don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/datasets/types/datasets_get_usage_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
class DatasetsGetUsageResponse(UncheckedBaseModel):
    """Response payload of the dataset storage-usage endpoint."""

    organization_usage: typing.Optional[int] = pydantic.Field(default=None)
    """
    The total number of bytes used by the organization.
    """

    if IS_PYDANTIC_V2:
        # Allow unknown fields so new server-side keys don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/datasets/types/datasets_list_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.dataset import Dataset
class DatasetsListResponse(UncheckedBaseModel):
    """Response payload of the dataset-list endpoint."""

    # The matching datasets; may be absent/None when there are no results.
    datasets: typing.Optional[typing.List[Dataset]] = None

    if IS_PYDANTIC_V2:
        # Allow unknown fields so new server-side keys don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/embed_jobs/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.

# isort: skip_file

import typing
from importlib import import_module

# Import below is only evaluated by static type checkers; at runtime the name
# is resolved lazily through the module-level __getattr__ below (PEP 562).
if typing.TYPE_CHECKING:
    from .types import CreateEmbedJobRequestTruncate

# Public name -> relative module that defines it; consumed by __getattr__.
_dynamic_imports: typing.Dict[str, str] = {"CreateEmbedJobRequestTruncate": ".types"}
def __getattr__(attr_name: str) -> typing.Any:
module_name = _dynamic_imports.get(attr_name)
if module_name is None:
raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
try:
module = import_module(module_name, __package__)
if module_name == f".{attr_name}":
return module
else:
return getattr(module, attr_name)
except ImportError as e:
raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
except AttributeError as e:
raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily importable names for dir()/autocomplete."""
    return sorted(_dynamic_imports)
__all__ = ["CreateEmbedJobRequestTruncate"]
================================================
FILE: src/cohere/embed_jobs/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from ..types.create_embed_job_response import CreateEmbedJobResponse
from ..types.embed_input_type import EmbedInputType
from ..types.embed_job import EmbedJob
from ..types.embedding_type import EmbeddingType
from ..types.list_embed_job_response import ListEmbedJobResponse
from .raw_client import AsyncRawEmbedJobsClient, RawEmbedJobsClient
from .types.create_embed_job_request_truncate import CreateEmbedJobRequestTruncate
# Sentinel default for optional parameters: Ellipsis distinguishes
# "argument not provided" from an explicit ``None``.
OMIT = typing.cast(typing.Any, ...)
class EmbedJobsClient:
    """High-level synchronous client for the Embed Jobs API.

    A thin convenience layer over :class:`RawEmbedJobsClient`: each method
    delegates to the raw client and returns the parsed ``.data`` payload
    instead of the wrapping HTTP response.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawEmbedJobsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawEmbedJobsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawEmbedJobsClient
        """
        return self._raw_client

    def list(self, *, request_options: typing.Optional[RequestOptions] = None) -> ListEmbedJobResponse:
        """
        View all embed jobs in the calling user's history.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListEmbedJobResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.embed_jobs.list()
        """
        return self._raw_client.list(request_options=request_options).data

    def create(
        self,
        *,
        model: str,
        dataset_id: str,
        input_type: EmbedInputType,
        name: typing.Optional[str] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> CreateEmbedJobResponse:
        """
        Launch an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets)
        of type `embed-input`. A completed job produces a new Dataset of type
        `embed-output` containing the original text entries plus their embeddings.

        Parameters
        ----------
        model : str
            ID of the embedding model. Available models and embedding dimensions:
            `embed-english-v3.0` (1024), `embed-multilingual-v3.0` (1024),
            `embed-english-light-v3.0` (384), `embed-multilingual-light-v3.0` (384).

        dataset_id : str
            ID of a [Dataset](https://docs.cohere.com/docs/datasets) of type
            `embed-input` with validation status `Validated`.

        input_type : EmbedInputType

        name : typing.Optional[str]
            The name of the embed job.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Embedding types to return; defaults to the Embed Floats response type.
            One or more of `"float"`, `"int8"`, `"uint8"`, `"binary"`, `"ubinary"`
            (the non-float types require v3 or newer models).

        truncate : typing.Optional[CreateEmbedJobRequestTruncate]
            One of `START|END`: which side of an over-length input to discard until
            the remainder fits the model's maximum input token length.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateEmbedJobResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.embed_jobs.create(
            model="model",
            dataset_id="dataset_id",
            input_type="search_document",
        )
        """
        raw_response = self._raw_client.create(
            model=model,
            dataset_id=dataset_id,
            input_type=input_type,
            name=name,
            embedding_types=embedding_types,
            truncate=truncate,
            request_options=request_options,
        )
        return raw_response.data

    def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> EmbedJob:
        """
        Retrieve the details of an embed job started by the same user.

        Parameters
        ----------
        id : str
            The ID of the embed job to retrieve.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedJob
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.embed_jobs.get(
            id="id",
        )
        """
        return self._raw_client.get(id, request_options=request_options).data

    def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> None:
        """
        Cancel an active embed job. Embedding stops immediately; the user is charged
        for embeddings processed up to the cancellation point, and partial results
        are not made available.

        Parameters
        ----------
        id : str
            The ID of the embed job to cancel.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        None

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.embed_jobs.cancel(
            id="id",
        )
        """
        return self._raw_client.cancel(id, request_options=request_options).data
class AsyncEmbedJobsClient:
    """Asynchronous high-level client for the embed-jobs endpoints.

    Each method delegates to :class:`AsyncRawEmbedJobsClient` and returns the
    parsed response data rather than the full HTTP response.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawEmbedJobsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawEmbedJobsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawEmbedJobsClient
        """
        return self._raw_client

    async def list(self, *, request_options: typing.Optional[RequestOptions] = None) -> ListEmbedJobResponse:
        """
        View the embed-job history for the authenticated user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListEmbedJobResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.embed_jobs.list()

        asyncio.run(main())
        """
        response = await self._raw_client.list(request_options=request_options)
        return response.data

    async def create(
        self,
        *,
        model: str,
        dataset_id: str,
        input_type: EmbedInputType,
        name: typing.Optional[str] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> CreateEmbedJobResponse:
        """
        Launch an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`.
        A completed embed job produces a new Dataset of type `embed-output` containing the
        original text entries together with their embeddings.

        Parameters
        ----------
        model : str
            ID of the embedding model.

            Available models and corresponding embedding dimensions:

            - `embed-english-v3.0` : 1024
            - `embed-multilingual-v3.0` : 1024
            - `embed-english-light-v3.0` : 384
            - `embed-multilingual-light-v3.0` : 384

        dataset_id : str
            ID of a [Dataset](https://docs.cohere.com/docs/datasets). The Dataset must be of type `embed-input` and must have a validation status `Validated`

        input_type : EmbedInputType

        name : typing.Optional[str]
            The name of the embed job.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Not required and default is None,
            which returns the Embed Floats response type. Can be one or more of the following types:

            * `"float"`: the default float embeddings. Valid for all models.
            * `"int8"`: signed int8 embeddings. Valid for v3 and newer model versions.
            * `"uint8"`: unsigned int8 embeddings. Valid for v3 and newer model versions.
            * `"binary"`: signed binary embeddings. Valid for v3 and newer model versions.
            * `"ubinary"`: unsigned binary embeddings. Valid for v3 and newer model versions.

        truncate : typing.Optional[CreateEmbedJobRequestTruncate]
            One of `START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input.
            In both cases, input is discarded until the remaining input is exactly the maximum input token
            length for the model.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateEmbedJobResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.embed_jobs.create(
                model="model",
                dataset_id="dataset_id",
                input_type="search_document",
            )

        asyncio.run(main())
        """
        response = await self._raw_client.create(
            model=model,
            dataset_id=dataset_id,
            input_type=input_type,
            name=name,
            embedding_types=embedding_types,
            truncate=truncate,
            request_options=request_options,
        )
        return response.data

    async def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> EmbedJob:
        """
        Retrieve the details of an embed job started by the same user.

        Parameters
        ----------
        id : str
            The ID of the embed job to retrieve.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedJob
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.embed_jobs.get(
                id="id",
            )

        asyncio.run(main())
        """
        response = await self._raw_client.get(id, request_options=request_options)
        return response.data

    async def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> None:
        """
        Cancel an active embed job. Once invoked, the embedding process is terminated and
        you are charged for the embeddings processed up to the cancellation point. Note
        that partial results are not available after cancellation.

        Parameters
        ----------
        id : str
            The ID of the embed job to cancel.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        None

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.embed_jobs.cancel(
                id="id",
            )

        asyncio.run(main())
        """
        response = await self._raw_client.cancel(id, request_options=request_options)
        return response.data
================================================
FILE: src/cohere/embed_jobs/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from json.decoder import JSONDecodeError
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.http_response import AsyncHttpResponse, HttpResponse
from ..core.jsonable_encoder import jsonable_encoder
from ..core.parse_error import ParsingError
from ..core.request_options import RequestOptions
from ..core.unchecked_base_model import construct_type
from ..errors.bad_request_error import BadRequestError
from ..errors.client_closed_request_error import ClientClosedRequestError
from ..errors.forbidden_error import ForbiddenError
from ..errors.gateway_timeout_error import GatewayTimeoutError
from ..errors.internal_server_error import InternalServerError
from ..errors.invalid_token_error import InvalidTokenError
from ..errors.not_found_error import NotFoundError
from ..errors.not_implemented_error import NotImplementedError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.too_many_requests_error import TooManyRequestsError
from ..errors.unauthorized_error import UnauthorizedError
from ..errors.unprocessable_entity_error import UnprocessableEntityError
from ..types.create_embed_job_response import CreateEmbedJobResponse
from ..types.embed_input_type import EmbedInputType
from ..types.embed_job import EmbedJob
from ..types.embedding_type import EmbeddingType
from ..types.list_embed_job_response import ListEmbedJobResponse
from .types.create_embed_job_request_truncate import CreateEmbedJobRequestTruncate
from pydantic import ValidationError
# Sentinel default for optional request parameters: lets the serializer
# distinguish "not provided" (omitted from the request body) from an
# explicit None value supplied by the caller.
OMIT = typing.cast(typing.Any, ...)
class RawEmbedJobsClient:
    """
    Synchronous raw client for the embed-jobs endpoints.

    Each method returns an ``HttpResponse`` wrapper exposing both the parsed
    response data and the underlying HTTP response. Non-2xx responses raise the
    typed error registered for the status code in ``_STATUS_ERRORS``; unknown
    statuses raise a generic ``ApiError``.
    """

    # Single shared mapping from HTTP status code to the typed error raised for
    # it. This replaces the twelve identical raise-branches that were previously
    # duplicated in every endpoint method.
    _STATUS_ERRORS: typing.Dict[int, typing.Any] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        # Shared HTTP transport (auth headers, base URL, retries) owned by the parent client.
        self._client_wrapper = client_wrapper

    def _raise_status_error(self, _response: typing.Any) -> None:
        """
        Raise the typed error registered for ``_response.status_code``, if any.

        The error body is parsed with ``construct_type`` exactly as the generated
        per-status branches did. May raise ``JSONDecodeError`` (handled by the
        caller's try/except) if the body is not valid JSON. Returns normally when
        the status code has no registered error type.
        """
        error_cls = self._STATUS_ERRORS.get(_response.status_code)
        if error_cls is not None:
            raise error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )

    def list(self, *, request_options: typing.Optional[RequestOptions] = None) -> HttpResponse[ListEmbedJobResponse]:
        """
        The list embed job endpoint allows users to view all embed jobs history for that specific user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ListEmbedJobResponse]
            OK
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/embed-jobs",
            method="GET",
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    ListEmbedJobResponse,
                    construct_type(
                        type_=ListEmbedJobResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            self._raise_status_error(_response)
            _response_json = _response.json()
        except JSONDecodeError:
            # Non-JSON body: surface the raw text in a generic ApiError.
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        # Unrecognized status code with a JSON body.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def create(
        self,
        *,
        model: str,
        dataset_id: str,
        input_type: EmbedInputType,
        name: typing.Optional[str] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[CreateEmbedJobResponse]:
        """
        This API launches an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`. The result of a completed embed job is new Dataset of type `embed-output`, which contains the original text entries and the corresponding embeddings.

        Parameters
        ----------
        model : str
            ID of the embedding model.

            Available models and corresponding embedding dimensions:

            - `embed-english-v3.0` : 1024
            - `embed-multilingual-v3.0` : 1024
            - `embed-english-light-v3.0` : 384
            - `embed-multilingual-light-v3.0` : 384

        dataset_id : str
            ID of a [Dataset](https://docs.cohere.com/docs/datasets). The Dataset must be of type `embed-input` and must have a validation status `Validated`

        input_type : EmbedInputType

        name : typing.Optional[str]
            The name of the embed job.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for v3 and newer model versions.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for v3 and newer model versions.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Valid for v3 and newer model versions.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for v3 and newer model versions.

        truncate : typing.Optional[CreateEmbedJobRequestTruncate]
            One of `START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[CreateEmbedJobResponse]
            OK
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/embed-jobs",
            method="POST",
            json={
                "model": model,
                "dataset_id": dataset_id,
                "input_type": input_type,
                "name": name,
                "embedding_types": embedding_types,
                "truncate": truncate,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    CreateEmbedJobResponse,
                    construct_type(
                        type_=CreateEmbedJobResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            self._raise_status_error(_response)
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def get(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> HttpResponse[EmbedJob]:
        """
        This API retrieves the details about an embed job started by the same user.

        Parameters
        ----------
        id : str
            The ID of the embed job to retrieve.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[EmbedJob]
            OK
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/embed-jobs/{jsonable_encoder(id)}",
            method="GET",
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    EmbedJob,
                    construct_type(
                        type_=EmbedJob,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            self._raise_status_error(_response)
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = None) -> HttpResponse[None]:
        """
        This API allows users to cancel an active embed job. Once invoked, the embedding process will be terminated, and users will be charged for the embeddings processed up to the cancellation point. It's important to note that partial results will not be available to users after cancellation.

        Parameters
        ----------
        id : str
            The ID of the embed job to cancel.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[None]
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/embed-jobs/{jsonable_encoder(id)}/cancel",
            method="POST",
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                # Cancel returns no body on success.
                return HttpResponse(response=_response, data=None)
            self._raise_status_error(_response)
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
class AsyncRawEmbedJobsClient:
    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        # Async HTTP transport (auth headers, base URL, retries) shared with the
        # parent client; every request in this class goes through it.
        self._client_wrapper = client_wrapper
async def list(
self, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[ListEmbedJobResponse]:
"""
The list embed job endpoint allows users to view all embed jobs history for that specific user.
Parameters
----------
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[ListEmbedJobResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
"v1/embed-jobs",
method="GET",
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
ListEmbedJobResponse,
construct_type(
type_=ListEmbedJobResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def create(
    self,
    *,
    model: str,
    dataset_id: str,
    input_type: EmbedInputType,
    name: typing.Optional[str] = OMIT,
    embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
    truncate: typing.Optional[CreateEmbedJobRequestTruncate] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[CreateEmbedJobResponse]:
    """
    Launches an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`.
    A completed embed job produces a new Dataset of type `embed-output` holding the original
    text entries together with their embeddings.

    Parameters
    ----------
    model : str
        ID of the embedding model.

        Available models and corresponding embedding dimensions:

        - `embed-english-v3.0` : 1024
        - `embed-multilingual-v3.0` : 1024
        - `embed-english-light-v3.0` : 384
        - `embed-multilingual-light-v3.0` : 384

    dataset_id : str
        ID of a [Dataset](https://docs.cohere.com/docs/datasets). The Dataset must be of type
        `embed-input` and must have a validation status `Validated`.

    input_type : EmbedInputType

    name : typing.Optional[str]
        The name of the embed job.

    embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
        Specifies the types of embeddings you want to get back. Not required; the default is
        None, which returns the Embed Floats response type. Can be one or more of:

        * `"float"`: default float embeddings. Valid for all models.
        * `"int8"`: signed int8 embeddings. Valid for v3 and newer model versions.
        * `"uint8"`: unsigned int8 embeddings. Valid for v3 and newer model versions.
        * `"binary"`: signed binary embeddings. Valid for v3 and newer model versions.
        * `"ubinary"`: unsigned binary embeddings. Valid for v3 and newer model versions.

    truncate : typing.Optional[CreateEmbedJobRequestTruncate]
        One of `START|END` to specify how the API will handle inputs longer than the maximum
        token length. `START` discards the start of the input and `END` discards the end; in
        both cases input is discarded until the remainder fits the model's maximum input
        token length.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[CreateEmbedJobResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/embed-jobs",
        method="POST",
        json={
            "model": model,
            "dataset_id": dataset_id,
            "input_type": input_type,
            "name": name,
            "embedding_types": embedding_types,
            "truncate": truncate,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    # Every documented error status mapped to the exception type it raises; the
    # dispatch below is a single lookup instead of a dozen duplicated branches.
    _status_to_error: typing.Dict[int, typing.Type[ApiError]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            parsed = typing.cast(
                CreateEmbedJobResponse,
                construct_type(
                    type_=CreateEmbedJobResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=parsed)
        error_cls = _status_to_error.get(_response.status_code)
        if error_cls is not None:
            raise error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def get(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[EmbedJob]:
    """
    Retrieves the details of an embed job started by the same user.

    Parameters
    ----------
    id : str
        The ID of the embed job to retrieve.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[EmbedJob]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        f"v1/embed-jobs/{jsonable_encoder(id)}",
        method="GET",
        request_options=request_options,
    )
    # Every documented error status mapped to the exception type it raises; the
    # dispatch below is a single lookup instead of a dozen duplicated branches.
    _status_to_error: typing.Dict[int, typing.Type[ApiError]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            parsed = typing.cast(
                EmbedJob,
                construct_type(
                    type_=EmbedJob,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=parsed)
        error_cls = _status_to_error.get(_response.status_code)
        if error_cls is not None:
            raise error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def cancel(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[None]:
    """
    Cancels an active embed job. Once invoked, the embedding process is
    terminated and the user is charged for the embeddings processed up to the
    cancellation point. Note that partial results are not available after
    cancellation.

    Parameters
    ----------
    id : str
        The ID of the embed job to cancel.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[None]
    """
    _response = await self._client_wrapper.httpx_client.request(
        f"v1/embed-jobs/{jsonable_encoder(id)}/cancel",
        method="POST",
        request_options=request_options,
    )
    # Every documented error status mapped to the exception type it raises; the
    # dispatch below is a single lookup instead of a dozen duplicated branches.
    _status_to_error: typing.Dict[int, typing.Type[ApiError]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            # Success carries no body payload for this endpoint.
            return AsyncHttpResponse(response=_response, data=None)
        error_cls = _status_to_error.get(_response.status_code)
        if error_cls is not None:
            raise error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/embed_jobs/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .create_embed_job_request_truncate import CreateEmbedJobRequestTruncate
# Maps each public name to the relative submodule that defines it, so
# __getattr__ below can defer the import until first attribute access.
_dynamic_imports: typing.Dict[str, str] = {"CreateEmbedJobRequestTruncate": ".create_embed_job_request_truncate"}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve *attr_name* lazily via the ``_dynamic_imports`` table (PEP 562)."""
    target = _dynamic_imports.get(attr_name)
    if target is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(target, __package__)
        # A self-named entry means the attribute *is* the submodule itself.
        return resolved if target == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {target}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {target}: {e}") from e
def __dir__():
    """Advertise the lazily importable names for dir() and tab completion."""
    return sorted(_dynamic_imports)
__all__ = ["CreateEmbedJobRequestTruncate"]
================================================
FILE: src/cohere/embed_jobs/types/create_embed_job_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Known truncation strategies; the union with typing.Any deliberately accepts
# values outside the listed literals (e.g. new server-side options).
CreateEmbedJobRequestTruncate = typing.Union[typing.Literal["START", "END"], typing.Any]
================================================
FILE: src/cohere/environment.py
================================================
# This file was auto-generated by Fern from our API Definition.
import enum
class ClientEnvironment(enum.Enum):
    """API environments selectable by the client; each member's value is the base URL."""

    PRODUCTION = "https://api.cohere.com"
================================================
FILE: src/cohere/errors/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .bad_request_error import BadRequestError
from .client_closed_request_error import ClientClosedRequestError
from .forbidden_error import ForbiddenError
from .gateway_timeout_error import GatewayTimeoutError
from .internal_server_error import InternalServerError
from .invalid_token_error import InvalidTokenError
from .not_found_error import NotFoundError
from .not_implemented_error import NotImplementedError
from .service_unavailable_error import ServiceUnavailableError
from .too_many_requests_error import TooManyRequestsError
from .unauthorized_error import UnauthorizedError
from .unprocessable_entity_error import UnprocessableEntityError
# Maps each public error name to the relative submodule that defines it, so
# __getattr__ below can defer the import until first attribute access.
_dynamic_imports: typing.Dict[str, str] = {
    "BadRequestError": ".bad_request_error",
    "ClientClosedRequestError": ".client_closed_request_error",
    "ForbiddenError": ".forbidden_error",
    "GatewayTimeoutError": ".gateway_timeout_error",
    "InternalServerError": ".internal_server_error",
    "InvalidTokenError": ".invalid_token_error",
    "NotFoundError": ".not_found_error",
    "NotImplementedError": ".not_implemented_error",
    "ServiceUnavailableError": ".service_unavailable_error",
    "TooManyRequestsError": ".too_many_requests_error",
    "UnauthorizedError": ".unauthorized_error",
    "UnprocessableEntityError": ".unprocessable_entity_error",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve *attr_name* lazily via the ``_dynamic_imports`` table (PEP 562)."""
    target = _dynamic_imports.get(attr_name)
    if target is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(target, __package__)
        # A self-named entry means the attribute *is* the submodule itself.
        return resolved if target == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {target}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {target}: {e}") from e
def __dir__():
    """Advertise the lazily importable names for dir() and tab completion."""
    return sorted(_dynamic_imports)
__all__ = [
"BadRequestError",
"ClientClosedRequestError",
"ForbiddenError",
"GatewayTimeoutError",
"InternalServerError",
"InvalidTokenError",
"NotFoundError",
"NotImplementedError",
"ServiceUnavailableError",
"TooManyRequestsError",
"UnauthorizedError",
"UnprocessableEntityError",
]
================================================
FILE: src/cohere/errors/bad_request_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class BadRequestError(ApiError):
    """Raised when the server responds with HTTP 400."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=400, body=body, headers=headers)
================================================
FILE: src/cohere/errors/client_closed_request_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class ClientClosedRequestError(ApiError):
    """Raised when the server responds with HTTP 499."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=499, body=body, headers=headers)
================================================
FILE: src/cohere/errors/forbidden_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class ForbiddenError(ApiError):
    """Raised when the server responds with HTTP 403."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=403, body=body, headers=headers)
================================================
FILE: src/cohere/errors/gateway_timeout_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class GatewayTimeoutError(ApiError):
    """Raised when the server responds with HTTP 504."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=504, body=body, headers=headers)
================================================
FILE: src/cohere/errors/internal_server_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class InternalServerError(ApiError):
    """Raised when the server responds with HTTP 500."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=500, body=body, headers=headers)
================================================
FILE: src/cohere/errors/invalid_token_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class InvalidTokenError(ApiError):
    """Raised when the server responds with HTTP 498."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=498, body=body, headers=headers)
================================================
FILE: src/cohere/errors/not_found_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class NotFoundError(ApiError):
    """Raised when the server responds with HTTP 404."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=404, body=body, headers=headers)
================================================
FILE: src/cohere/errors/not_implemented_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class NotImplementedError(ApiError):
    """Raised when the server responds with HTTP 501.

    NOTE: this intentionally shadows the builtin ``NotImplementedError``
    within modules that import it from this package.
    """

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=501, body=body, headers=headers)
================================================
FILE: src/cohere/errors/service_unavailable_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class ServiceUnavailableError(ApiError):
    """Raised when the server responds with HTTP 503."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=503, body=body, headers=headers)
================================================
FILE: src/cohere/errors/too_many_requests_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class TooManyRequestsError(ApiError):
    """Raised when the server responds with HTTP 429."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=429, body=body, headers=headers)
================================================
FILE: src/cohere/errors/unauthorized_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class UnauthorizedError(ApiError):
    """Raised when the server responds with HTTP 401."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=401, body=body, headers=headers)
================================================
FILE: src/cohere/errors/unprocessable_entity_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.api_error import ApiError
class UnprocessableEntityError(ApiError):
    """Raised when the server responds with HTTP 422."""

    def __init__(self, body: typing.Any, headers: typing.Optional[typing.Dict[str, str]] = None):
        super().__init__(status_code=422, body=body, headers=headers)
================================================
FILE: src/cohere/finetuning/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from . import finetuning
from .finetuning import (
BaseModel,
BaseType,
CreateFinetunedModelResponse,
DeleteFinetunedModelResponse,
Event,
FinetunedModel,
GetFinetunedModelResponse,
Hyperparameters,
ListEventsResponse,
ListFinetunedModelsResponse,
ListTrainingStepMetricsResponse,
LoraTargetModules,
Settings,
Status,
Strategy,
TrainingStepMetrics,
UpdateFinetunedModelResponse,
WandbConfig,
)
# Maps each public name to the relative module that defines it, so
# __getattr__ below can defer the import until first attribute access.
# Every entry here points at the nested ".finetuning" subpackage.
_dynamic_imports: typing.Dict[str, str] = {
    "BaseModel": ".finetuning",
    "BaseType": ".finetuning",
    "CreateFinetunedModelResponse": ".finetuning",
    "DeleteFinetunedModelResponse": ".finetuning",
    "Event": ".finetuning",
    "FinetunedModel": ".finetuning",
    "GetFinetunedModelResponse": ".finetuning",
    "Hyperparameters": ".finetuning",
    "ListEventsResponse": ".finetuning",
    "ListFinetunedModelsResponse": ".finetuning",
    "ListTrainingStepMetricsResponse": ".finetuning",
    "LoraTargetModules": ".finetuning",
    "Settings": ".finetuning",
    "Status": ".finetuning",
    "Strategy": ".finetuning",
    "TrainingStepMetrics": ".finetuning",
    "UpdateFinetunedModelResponse": ".finetuning",
    "WandbConfig": ".finetuning",
    "finetuning": ".finetuning",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve *attr_name* lazily via the ``_dynamic_imports`` table (PEP 562)."""
    target = _dynamic_imports.get(attr_name)
    if target is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(target, __package__)
        # A self-named entry means the attribute *is* the submodule itself.
        return resolved if target == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {target}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {target}: {e}") from e
def __dir__():
    """Advertise the lazily importable names for dir() and tab completion."""
    return sorted(_dynamic_imports)
__all__ = [
"BaseModel",
"BaseType",
"CreateFinetunedModelResponse",
"DeleteFinetunedModelResponse",
"Event",
"FinetunedModel",
"GetFinetunedModelResponse",
"Hyperparameters",
"ListEventsResponse",
"ListFinetunedModelsResponse",
"ListTrainingStepMetricsResponse",
"LoraTargetModules",
"Settings",
"Status",
"Strategy",
"TrainingStepMetrics",
"UpdateFinetunedModelResponse",
"WandbConfig",
"finetuning",
]
================================================
FILE: src/cohere/finetuning/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from .finetuning.types.create_finetuned_model_response import CreateFinetunedModelResponse
from .finetuning.types.delete_finetuned_model_response import DeleteFinetunedModelResponse
from .finetuning.types.finetuned_model import FinetunedModel
from .finetuning.types.get_finetuned_model_response import GetFinetunedModelResponse
from .finetuning.types.list_events_response import ListEventsResponse
from .finetuning.types.list_finetuned_models_response import ListFinetunedModelsResponse
from .finetuning.types.list_training_step_metrics_response import ListTrainingStepMetricsResponse
from .finetuning.types.settings import Settings
from .finetuning.types.update_finetuned_model_response import UpdateFinetunedModelResponse
from .raw_client import AsyncRawFinetuningClient, RawFinetuningClient
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)
class FinetuningClient:
def __init__(self, *, client_wrapper: SyncClientWrapper):
    # All HTTP work is delegated to the raw client; the methods on this class
    # only unwrap the parsed `.data` from the raw responses.
    self._raw_client = RawFinetuningClient(client_wrapper=client_wrapper)
@property
def with_raw_response(self) -> RawFinetuningClient:
    """
    Access the underlying raw client, whose methods return the full HTTP
    response objects rather than just the parsed data.

    Returns
    -------
    RawFinetuningClient
    """
    return self._raw_client
def list_finetuned_models(
    self,
    *,
    page_size: typing.Optional[int] = None,
    page_token: typing.Optional[str] = None,
    order_by: typing.Optional[str] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> ListFinetunedModelsResponse:
    """
    Returns a list of fine-tuned models that the user has access to.

    Parameters
    ----------
    page_size : typing.Optional[int]
        Maximum number of results returned by the server; 0 defaults to 50.

    page_token : typing.Optional[str]
        Request a specific page of the list results.

    order_by : typing.Optional[str]
        Comma separated list of fields, e.g. "created_at,name". The default
        sorting order is ascending; append " desc" to a field name for
        descending order, e.g. "created_at desc,name".

        Supported sorting fields:

          - created_at (default)

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    ListFinetunedModelsResponse
        A successful response.

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.finetuning.list_finetuned_models(
        page_size=1,
        page_token="page_token",
        order_by="order_by",
    )
    """
    raw_response = self._raw_client.list_finetuned_models(
        page_size=page_size,
        page_token=page_token,
        order_by=order_by,
        request_options=request_options,
    )
    return raw_response.data
def create_finetuned_model(
    self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None
) -> CreateFinetunedModelResponse:
    """
    Creates a new fine-tuned model, trained on the dataset specified in the
    request body. Training may take some time; the model becomes available
    once training completes.

    Parameters
    ----------
    request : FinetunedModel

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    CreateFinetunedModelResponse
        A successful response.

    Examples
    --------
    from cohere import Client
    from cohere.finetuning.finetuning import BaseModel, FinetunedModel, Settings

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.finetuning.create_finetuned_model(
        request=FinetunedModel(
            name="name",
            settings=Settings(
                base_model=BaseModel(
                    base_type="BASE_TYPE_UNSPECIFIED",
                ),
                dataset_id="dataset_id",
            ),
        ),
    )
    """
    raw_response = self._raw_client.create_finetuned_model(
        request=request, request_options=request_options
    )
    return raw_response.data
def get_finetuned_model(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> GetFinetunedModelResponse:
    """
    Retrieve a fine-tuned model by its ID.

    Parameters
    ----------
    id : str
        The fine-tuned model ID.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    GetFinetunedModelResponse
        A successful response.

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.finetuning.get_finetuned_model(
        id="id",
    )
    """
    raw_response = self._raw_client.get_finetuned_model(id, request_options=request_options)
    return raw_response.data
def delete_finetuned_model(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> DeleteFinetunedModelResponse:
    """
    Deletes a fine-tuned model. The model is removed from the system and is no
    longer available for use. This operation is irreversible.

    Parameters
    ----------
    id : str
        The fine-tuned model ID.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    DeleteFinetunedModelResponse
        A successful response.

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.finetuning.delete_finetuned_model(
        id="id",
    )
    """
    raw_response = self._raw_client.delete_finetuned_model(id, request_options=request_options)
    return raw_response.data
def update_finetuned_model(
    self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None
) -> UpdateFinetunedModelResponse:
    """
    Updates the fine-tuned model with the given ID, applying the new name and
    settings supplied in the request body.

    Parameters
    ----------
    id : str
        FinetunedModel ID.

    name : str
        FinetunedModel name (e.g. `foobar`).

    settings : Settings
        FinetunedModel settings such as dataset, hyperparameters...

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    UpdateFinetunedModelResponse
        A successful response.

    Examples
    --------
    from cohere import Client
    from cohere.finetuning.finetuning import BaseModel, Settings

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.finetuning.update_finetuned_model(
        id="id",
        name="name",
        settings=Settings(
            base_model=BaseModel(
                base_type="BASE_TYPE_UNSPECIFIED",
            ),
            dataset_id="dataset_id",
        ),
    )
    """
    raw_response = self._raw_client.update_finetuned_model(
        id,
        name=name,
        settings=settings,
        request_options=request_options,
    )
    return raw_response.data
def list_events(
self,
finetuned_model_id: str,
*,
page_size: typing.Optional[int] = None,
page_token: typing.Optional[str] = None,
order_by: typing.Optional[str] = None,
request_options: typing.Optional[RequestOptions] = None,
) -> ListEventsResponse:
"""
Returns a list of events that occurred during the life-cycle of the fine-tuned model.
The events are ordered by creation time, with the most recent event first.
The list can be paginated using `page_size` and `page_token` parameters.
Parameters
----------
finetuned_model_id : str
The parent fine-tuned model ID.
page_size : typing.Optional[int]
Maximum number of results to be returned by the server. If 0, defaults to
50.
page_token : typing.Optional[str]
Request a specific page of the list results.
order_by : typing.Optional[str]
Comma separated list of fields. For example: "created_at,name". The default
sorting order is ascending. To specify descending order for a field, append
" desc" to the field name. For example: "created_at desc,name".
Supported sorting fields:
- created_at (default)
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
ListEventsResponse
A successful response.
Examples
--------
from cohere import Client
client = Client(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)
client.finetuning.list_events(
finetuned_model_id="finetuned_model_id",
page_size=1,
page_token="page_token",
order_by="order_by",
)
"""
_response = self._raw_client.list_events(
finetuned_model_id,
page_size=page_size,
page_token=page_token,
order_by=order_by,
request_options=request_options,
)
return _response.data
def list_training_step_metrics(
self,
finetuned_model_id: str,
*,
page_size: typing.Optional[int] = None,
page_token: typing.Optional[str] = None,
request_options: typing.Optional[RequestOptions] = None,
) -> ListTrainingStepMetricsResponse:
"""
Returns a list of metrics measured during the training of a fine-tuned model.
The metrics are ordered by step number, with the most recent step first.
The list can be paginated using `page_size` and `page_token` parameters.
Parameters
----------
finetuned_model_id : str
The parent fine-tuned model ID.
page_size : typing.Optional[int]
Maximum number of results to be returned by the server. If 0, defaults to
50.
page_token : typing.Optional[str]
Request a specific page of the list results.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
ListTrainingStepMetricsResponse
A successful response.
Examples
--------
from cohere import Client
client = Client(
client_name="YOUR_CLIENT_NAME",
token="YOUR_TOKEN",
)
client.finetuning.list_training_step_metrics(
finetuned_model_id="finetuned_model_id",
page_size=1,
page_token="page_token",
)
"""
_response = self._raw_client.list_training_step_metrics(
finetuned_model_id, page_size=page_size, page_token=page_token, request_options=request_options
)
return _response.data
class AsyncFinetuningClient:
    """Asynchronous client for the fine-tuning endpoints; mirrors FinetuningClient."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawFinetuningClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawFinetuningClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawFinetuningClient
        """
        return self._raw_client

    async def list_finetuned_models(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListFinetunedModelsResponse:
        """
        List the fine-tuned models that the user has access to.

        Parameters
        ----------
        page_size : typing.Optional[int]
            Maximum number of results returned per page; 0 falls back to the
            server default of 50.

        page_token : typing.Optional[str]
            Token selecting a specific page of the list results.

        order_by : typing.Optional[str]
            Comma separated list of sort fields, e.g. "created_at,name".
            Ascending by default; append " desc" to a field for descending
            order, e.g. "created_at desc,name". Supported sorting fields:
            - created_at (default)

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListFinetunedModelsResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.finetuning.list_finetuned_models(
                page_size=1,
                page_token="page_token",
                order_by="order_by",
            )


        asyncio.run(main())
        """
        raw_response = await self._raw_client.list_finetuned_models(
            page_size=page_size,
            page_token=page_token,
            order_by=order_by,
            request_options=request_options,
        )
        return raw_response.data

    async def create_finetuned_model(
        self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None
    ) -> CreateFinetunedModelResponse:
        """
        Create a new fine-tuned model trained on the dataset named in the
        request body. Training may take some time; the model becomes
        available once training finishes.

        Parameters
        ----------
        request : FinetunedModel

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        CreateFinetunedModelResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient
        from cohere.finetuning.finetuning import BaseModel, FinetunedModel, Settings

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.finetuning.create_finetuned_model(
                request=FinetunedModel(
                    name="name",
                    settings=Settings(
                        base_model=BaseModel(
                            base_type="BASE_TYPE_UNSPECIFIED",
                        ),
                        dataset_id="dataset_id",
                    ),
                ),
            )


        asyncio.run(main())
        """
        raw_response = await self._raw_client.create_finetuned_model(
            request=request, request_options=request_options
        )
        return raw_response.data

    async def get_finetuned_model(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> GetFinetunedModelResponse:
        """
        Fetch a single fine-tuned model by its ID.

        Parameters
        ----------
        id : str
            The fine-tuned model ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetFinetunedModelResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.finetuning.get_finetuned_model(
                id="id",
            )


        asyncio.run(main())
        """
        raw_response = await self._raw_client.get_finetuned_model(id, request_options=request_options)
        return raw_response.data

    async def delete_finetuned_model(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> DeleteFinetunedModelResponse:
        """
        Permanently delete a fine-tuned model.

        The model is removed from the system and can no longer be used; the
        operation cannot be undone.

        Parameters
        ----------
        id : str
            The fine-tuned model ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        DeleteFinetunedModelResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.finetuning.delete_finetuned_model(
                id="id",
            )


        asyncio.run(main())
        """
        raw_response = await self._raw_client.delete_finetuned_model(id, request_options=request_options)
        return raw_response.data

    async def update_finetuned_model(
        self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None
    ) -> UpdateFinetunedModelResponse:
        """
        Replace the name and settings of an existing fine-tuned model.

        Parameters
        ----------
        id : str
            FinetunedModel ID.

        name : str
            FinetunedModel name (e.g. `foobar`).

        settings : Settings
            FinetunedModel settings such as dataset, hyperparameters...

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        UpdateFinetunedModelResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient
        from cohere.finetuning.finetuning import BaseModel, Settings

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.finetuning.update_finetuned_model(
                id="id",
                name="name",
                settings=Settings(
                    base_model=BaseModel(
                        base_type="BASE_TYPE_UNSPECIFIED",
                    ),
                    dataset_id="dataset_id",
                ),
            )


        asyncio.run(main())
        """
        raw_response = await self._raw_client.update_finetuned_model(
            id,
            name=name,
            settings=settings,
            request_options=request_options,
        )
        return raw_response.data

    async def list_events(
        self,
        finetuned_model_id: str,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListEventsResponse:
        """
        List the life-cycle events of a fine-tuned model.

        Events are ordered by creation time, most recent first, and may be
        paginated via `page_size` and `page_token`.

        Parameters
        ----------
        finetuned_model_id : str
            The parent fine-tuned model ID.

        page_size : typing.Optional[int]
            Maximum number of results returned per page; 0 falls back to the
            server default of 50.

        page_token : typing.Optional[str]
            Token selecting a specific page of the list results.

        order_by : typing.Optional[str]
            Comma separated list of sort fields, e.g. "created_at,name".
            Ascending by default; append " desc" to a field for descending
            order, e.g. "created_at desc,name". Supported sorting fields:
            - created_at (default)

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListEventsResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.finetuning.list_events(
                finetuned_model_id="finetuned_model_id",
                page_size=1,
                page_token="page_token",
                order_by="order_by",
            )


        asyncio.run(main())
        """
        raw_response = await self._raw_client.list_events(
            finetuned_model_id,
            page_size=page_size,
            page_token=page_token,
            order_by=order_by,
            request_options=request_options,
        )
        return raw_response.data

    async def list_training_step_metrics(
        self,
        finetuned_model_id: str,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListTrainingStepMetricsResponse:
        """
        List the metrics measured while training a fine-tuned model.

        Metrics are ordered by step number, most recent step first, and may
        be paginated via `page_size` and `page_token`.

        Parameters
        ----------
        finetuned_model_id : str
            The parent fine-tuned model ID.

        page_size : typing.Optional[int]
            Maximum number of results returned per page; 0 falls back to the
            server default of 50.

        page_token : typing.Optional[str]
            Token selecting a specific page of the list results.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListTrainingStepMetricsResponse
            A successful response.

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.finetuning.list_training_step_metrics(
                finetuned_model_id="finetuned_model_id",
                page_size=1,
                page_token="page_token",
            )


        asyncio.run(main())
        """
        raw_response = await self._raw_client.list_training_step_metrics(
            finetuned_model_id,
            page_size=page_size,
            page_token=page_token,
            request_options=request_options,
        )
        return raw_response.data
================================================
FILE: src/cohere/finetuning/finetuning/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module

# These imports run only under static type checking; at runtime the same
# names are resolved lazily through __getattr__ below, so importing this
# package stays cheap.
if typing.TYPE_CHECKING:
    from .types import (
        BaseModel,
        BaseType,
        CreateFinetunedModelResponse,
        DeleteFinetunedModelResponse,
        Event,
        FinetunedModel,
        GetFinetunedModelResponse,
        Hyperparameters,
        ListEventsResponse,
        ListFinetunedModelsResponse,
        ListTrainingStepMetricsResponse,
        LoraTargetModules,
        Settings,
        Status,
        Strategy,
        TrainingStepMetrics,
        UpdateFinetunedModelResponse,
        WandbConfig,
    )

# Maps each public attribute name to the relative module it lives in;
# consumed by __getattr__ to import on first access.
_dynamic_imports: typing.Dict[str, str] = {
    "BaseModel": ".types",
    "BaseType": ".types",
    "CreateFinetunedModelResponse": ".types",
    "DeleteFinetunedModelResponse": ".types",
    "Event": ".types",
    "FinetunedModel": ".types",
    "GetFinetunedModelResponse": ".types",
    "Hyperparameters": ".types",
    "ListEventsResponse": ".types",
    "ListFinetunedModelsResponse": ".types",
    "ListTrainingStepMetricsResponse": ".types",
    "LoraTargetModules": ".types",
    "Settings": ".types",
    "Status": ".types",
    "Strategy": ".types",
    "TrainingStepMetrics": ".types",
    "UpdateFinetunedModelResponse": ".types",
    "WandbConfig": ".types",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve *attr_name* by importing the module registered for it (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    # When the attribute maps to a submodule of the same name, return the module itself.
    if module_name == f".{attr_name}":
        return module
    try:
        return getattr(module, attr_name)
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily importable attribute names to dir()/autocomplete."""
    return sorted(_dynamic_imports)
# Public API surface; must stay in sync with _dynamic_imports above.
__all__ = [
    "BaseModel",
    "BaseType",
    "CreateFinetunedModelResponse",
    "DeleteFinetunedModelResponse",
    "Event",
    "FinetunedModel",
    "GetFinetunedModelResponse",
    "Hyperparameters",
    "ListEventsResponse",
    "ListFinetunedModelsResponse",
    "ListTrainingStepMetricsResponse",
    "LoraTargetModules",
    "Settings",
    "Status",
    "Strategy",
    "TrainingStepMetrics",
    "UpdateFinetunedModelResponse",
    "WandbConfig",
]
================================================
FILE: src/cohere/finetuning/finetuning/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module

# These imports run only under static type checking; at runtime the same
# names are resolved lazily through __getattr__ below, one submodule per type.
if typing.TYPE_CHECKING:
    from .base_model import BaseModel
    from .base_type import BaseType
    from .create_finetuned_model_response import CreateFinetunedModelResponse
    from .delete_finetuned_model_response import DeleteFinetunedModelResponse
    from .event import Event
    from .finetuned_model import FinetunedModel
    from .get_finetuned_model_response import GetFinetunedModelResponse
    from .hyperparameters import Hyperparameters
    from .list_events_response import ListEventsResponse
    from .list_finetuned_models_response import ListFinetunedModelsResponse
    from .list_training_step_metrics_response import ListTrainingStepMetricsResponse
    from .lora_target_modules import LoraTargetModules
    from .settings import Settings
    from .status import Status
    from .strategy import Strategy
    from .training_step_metrics import TrainingStepMetrics
    from .update_finetuned_model_response import UpdateFinetunedModelResponse
    from .wandb_config import WandbConfig

# Maps each public attribute name to the relative module it lives in;
# consumed by __getattr__ to import on first access.
_dynamic_imports: typing.Dict[str, str] = {
    "BaseModel": ".base_model",
    "BaseType": ".base_type",
    "CreateFinetunedModelResponse": ".create_finetuned_model_response",
    "DeleteFinetunedModelResponse": ".delete_finetuned_model_response",
    "Event": ".event",
    "FinetunedModel": ".finetuned_model",
    "GetFinetunedModelResponse": ".get_finetuned_model_response",
    "Hyperparameters": ".hyperparameters",
    "ListEventsResponse": ".list_events_response",
    "ListFinetunedModelsResponse": ".list_finetuned_models_response",
    "ListTrainingStepMetricsResponse": ".list_training_step_metrics_response",
    "LoraTargetModules": ".lora_target_modules",
    "Settings": ".settings",
    "Status": ".status",
    "Strategy": ".strategy",
    "TrainingStepMetrics": ".training_step_metrics",
    "UpdateFinetunedModelResponse": ".update_finetuned_model_response",
    "WandbConfig": ".wandb_config",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve *attr_name* by importing the module registered for it (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    # When the attribute maps to a submodule of the same name, return the module itself.
    if module_name == f".{attr_name}":
        return module
    try:
        return getattr(module, attr_name)
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily importable attribute names to dir()/autocomplete."""
    return sorted(_dynamic_imports)
# Public API surface; must stay in sync with _dynamic_imports above.
__all__ = [
    "BaseModel",
    "BaseType",
    "CreateFinetunedModelResponse",
    "DeleteFinetunedModelResponse",
    "Event",
    "FinetunedModel",
    "GetFinetunedModelResponse",
    "Hyperparameters",
    "ListEventsResponse",
    "ListFinetunedModelsResponse",
    "ListTrainingStepMetricsResponse",
    "LoraTargetModules",
    "Settings",
    "Status",
    "Strategy",
    "TrainingStepMetrics",
    "UpdateFinetunedModelResponse",
    "WandbConfig",
]
================================================
FILE: src/cohere/finetuning/finetuning/types/base_model.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .base_type import BaseType
from .strategy import Strategy
class BaseModel(UncheckedBaseModel):
    """
    The base model used for fine-tuning.
    """

    name: typing.Optional[str] = pydantic.Field(default=None)
    """
    The name of the base model.
    """

    version: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. The version of the base model.
    """

    base_type: BaseType = pydantic.Field()
    """
    The type of the base model.
    """

    strategy: typing.Optional[Strategy] = pydantic.Field(default=None)
    """
    Deprecated: The fine-tuning strategy.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/base_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Classification of a fine-tunable base model. The union with typing.Any
# lets values unknown to this SDK version pass validation unchanged.
BaseType = typing.Union[
    typing.Literal[
        "BASE_TYPE_UNSPECIFIED",
        "BASE_TYPE_GENERATIVE",
        "BASE_TYPE_CLASSIFICATION",
        "BASE_TYPE_RERANK",
        "BASE_TYPE_CHAT",
    ],
    typing.Any,
]
================================================
FILE: src/cohere/finetuning/finetuning/types/create_finetuned_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .finetuned_model import FinetunedModel
class CreateFinetunedModelResponse(UncheckedBaseModel):
    """
    Response to request to create a fine-tuned model.
    """

    finetuned_model: typing.Optional[FinetunedModel] = pydantic.Field(default=None)
    """
    Information about the fine-tuned model.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/delete_finetuned_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# No structured schema is declared for this response, so it is surfaced as a
# raw JSON object (dict).
DeleteFinetunedModelResponse = typing.Dict[str, typing.Any]
"""
Response to request to delete a fine-tuned model.
"""
================================================
FILE: src/cohere/finetuning/finetuning/types/event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .status import Status
class Event(UncheckedBaseModel):
    """
    A change in status of a fine-tuned model.
    """

    user_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    ID of the user who initiated the event. Empty if initiated by the system.
    """

    status: typing.Optional[Status] = pydantic.Field(default=None)
    """
    Status of the fine-tuned model.
    """

    created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    Timestamp when the event happened.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/finetuned_model.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .settings import Settings
from .status import Status
class FinetunedModel(UncheckedBaseModel):
    """
    This resource represents a fine-tuned model.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. FinetunedModel ID.
    """

    name: str = pydantic.Field()
    """
    FinetunedModel name (e.g. `foobar`).
    """

    creator_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. User ID of the creator.
    """

    organization_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    read-only. Organization ID.
    """

    settings: Settings = pydantic.Field()
    """
    FinetunedModel settings such as dataset, hyperparameters...
    """

    status: typing.Optional[Status] = pydantic.Field(default=None)
    """
    read-only. Current stage in the life-cycle of the fine-tuned model.
    """

    created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Creation timestamp.
    """

    updated_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Latest update timestamp.
    """

    completed_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Timestamp for the completed fine-tuning.
    """

    last_used: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    read-only. Deprecated: Timestamp for the latest request to this fine-tuned model.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/get_finetuned_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .finetuned_model import FinetunedModel
class GetFinetunedModelResponse(UncheckedBaseModel):
    """
    Response to a request to get a fine-tuned model.
    """

    finetuned_model: typing.Optional[FinetunedModel] = pydantic.Field(default=None)
    """
    Information about the fine-tuned model.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/hyperparameters.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .lora_target_modules import LoraTargetModules
class Hyperparameters(UncheckedBaseModel):
    """
    The fine-tuning hyperparameters.
    """

    early_stopping_patience: typing.Optional[int] = pydantic.Field(default=None)
    """
    Stops training if the loss metric does not improve beyond the value of
    `early_stopping_threshold` after this many times of evaluation.
    """

    early_stopping_threshold: typing.Optional[float] = pydantic.Field(default=None)
    """
    How much the loss must improve to prevent early stopping.
    """

    train_batch_size: typing.Optional[int] = pydantic.Field(default=None)
    """
    The batch size is the number of training examples included in a single
    training pass.
    """

    train_epochs: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of epochs to train for.
    """

    learning_rate: typing.Optional[float] = pydantic.Field(default=None)
    """
    The learning rate to be used during training.
    """

    lora_alpha: typing.Optional[int] = pydantic.Field(default=None)
    """
    Controls the scaling factor for LoRA updates. Higher values make the
    updates more impactful.
    """

    lora_rank: typing.Optional[int] = pydantic.Field(default=None)
    """
    Specifies the rank for low-rank matrices. Lower ranks reduce parameters
    but may limit model flexibility.
    """

    lora_target_modules: typing.Optional[LoraTargetModules] = pydantic.Field(default=None)
    """
    The combination of LoRA modules to target.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/list_events_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .event import Event
class ListEventsResponse(UncheckedBaseModel):
    """
    Response to a request to list events of a fine-tuned model.
    """

    events: typing.Optional[typing.List[Event]] = pydantic.Field(default=None)
    """
    List of events for the fine-tuned model.
    """

    next_page_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    Pagination token to retrieve the next page of results. If the value is "",
    it means no further results for the request.
    """

    total_size: typing.Optional[int] = pydantic.Field(default=None)
    """
    Total count of results.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/list_finetuned_models_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .finetuned_model import FinetunedModel
class ListFinetunedModelsResponse(UncheckedBaseModel):
    """
    Response to a request to list fine-tuned models.
    """

    finetuned_models: typing.Optional[typing.List[FinetunedModel]] = pydantic.Field(default=None)
    """
    List of fine-tuned models matching the request.
    """

    next_page_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    Pagination token to retrieve the next page of results. If the value is "",
    it means no further results for the request.
    """

    total_size: typing.Optional[int] = pydantic.Field(default=None)
    """
    Total count of results.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/list_training_step_metrics_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .training_step_metrics import TrainingStepMetrics
class ListTrainingStepMetricsResponse(UncheckedBaseModel):
    """
    Response to a request to list training-step metrics of a fine-tuned model.
    """

    step_metrics: typing.Optional[typing.List[TrainingStepMetrics]] = pydantic.Field(default=None)
    """
    The metrics for each step the evaluation was run on.
    """

    next_page_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    Pagination token to retrieve the next page of results. If the value is "",
    it means no further results for the request.
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/lora_target_modules.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Allowed combinations of LoRA modules to target. The union with typing.Any
# lets values unknown to this SDK version pass validation unchanged.
LoraTargetModules = typing.Union[
    typing.Literal[
        "LORA_TARGET_MODULES_UNSPECIFIED",
        "LORA_TARGET_MODULES_QV",
        "LORA_TARGET_MODULES_QKVO",
        "LORA_TARGET_MODULES_QKVO_FFN",
    ],
    typing.Any,
]
================================================
FILE: src/cohere/finetuning/finetuning/types/settings.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .base_model import BaseModel
from .hyperparameters import Hyperparameters
from .wandb_config import WandbConfig
class Settings(UncheckedBaseModel):
    """
    The configuration used for fine-tuning.
    """

    base_model: BaseModel = pydantic.Field()
    """
    The base model to fine-tune.
    """

    dataset_id: str = pydantic.Field()
    """
    The data used for training and evaluating the fine-tuned model.
    """

    hyperparameters: typing.Optional[Hyperparameters] = pydantic.Field(default=None)
    """
    Fine-tuning hyper-parameters.
    """

    multi_label: typing.Optional[bool] = pydantic.Field(default=None)
    """
    read-only. Whether the model is single-label or multi-label (only for classification).
    """

    wandb: typing.Optional[WandbConfig] = pydantic.Field(default=None)
    """
    The Weights & Biases configuration (Chat fine-tuning only).
    """

    # Config branches on the installed pydantic major version; extra="allow"
    # retains fields the API returns that this model does not declare.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/status.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Life-cycle stages of a fine-tuned model. The union with typing.Any lets
# status values unknown to this SDK version pass validation unchanged.
Status = typing.Union[
    typing.Literal[
        "STATUS_UNSPECIFIED",
        "STATUS_FINETUNING",
        "STATUS_DEPLOYING_API",
        "STATUS_READY",
        "STATUS_FAILED",
        "STATUS_DELETED",
        "STATUS_TEMPORARILY_OFFLINE",
        "STATUS_PAUSED",
        "STATUS_QUEUED",
    ],
    typing.Any,
]
================================================
FILE: src/cohere/finetuning/finetuning/types/strategy.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Fine-tuning strategy. Unioned with typing.Any so values introduced by the
# API after this SDK release still deserialize (forward compatibility).
Strategy = typing.Union[typing.Literal["STRATEGY_UNSPECIFIED", "STRATEGY_VANILLA", "STRATEGY_TFEW"], typing.Any]
================================================
FILE: src/cohere/finetuning/finetuning/types/training_step_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
class TrainingStepMetrics(UncheckedBaseModel):
    """
    The evaluation metrics at a given step of the training of a fine-tuned model.
    """

    created_at: typing.Optional[dt.datetime] = pydantic.Field(default=None)
    """
    Creation timestamp.
    """

    step_number: typing.Optional[int] = pydantic.Field(default=None)
    """
    Step number.
    """

    metrics: typing.Optional[typing.Dict[str, float]] = pydantic.Field(default=None)
    """
    Map of names and values for each evaluation metrics.
    """

    # Allow unknown fields in both pydantic major versions so that responses
    # from newer API revisions still deserialize instead of failing validation.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/update_finetuned_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
from .finetuned_model import FinetunedModel
class UpdateFinetunedModelResponse(UncheckedBaseModel):
    """
    Response to a request to update a fine-tuned model.
    """

    finetuned_model: typing.Optional[FinetunedModel] = pydantic.Field(default=None)
    """
    Information about the fine-tuned model.
    """

    # Allow unknown fields in both pydantic major versions so that responses
    # from newer API revisions still deserialize instead of failing validation.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/finetuning/types/wandb_config.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ....core.pydantic_utilities import IS_PYDANTIC_V2
from ....core.unchecked_base_model import UncheckedBaseModel
class WandbConfig(UncheckedBaseModel):
    """
    The Weights & Biases configuration.
    """

    project: str = pydantic.Field()
    """
    The WandB project name to be used during training.
    """

    api_key: str = pydantic.Field()
    """
    The WandB API key to be used during training.
    """

    entity: typing.Optional[str] = pydantic.Field(default=None)
    """
    The WandB entity name to be used during training.
    """

    # Allow unknown fields in both pydantic major versions so that responses
    # from newer API revisions still deserialize instead of failing validation.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/finetuning/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from json.decoder import JSONDecodeError
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.http_response import AsyncHttpResponse, HttpResponse
from ..core.jsonable_encoder import jsonable_encoder
from ..core.parse_error import ParsingError
from ..core.request_options import RequestOptions
from ..core.serialization import convert_and_respect_annotation_metadata
from ..core.unchecked_base_model import construct_type
from ..errors.bad_request_error import BadRequestError
from ..errors.forbidden_error import ForbiddenError
from ..errors.internal_server_error import InternalServerError
from ..errors.not_found_error import NotFoundError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.unauthorized_error import UnauthorizedError
from .finetuning.types.create_finetuned_model_response import CreateFinetunedModelResponse
from .finetuning.types.delete_finetuned_model_response import DeleteFinetunedModelResponse
from .finetuning.types.finetuned_model import FinetunedModel
from .finetuning.types.get_finetuned_model_response import GetFinetunedModelResponse
from .finetuning.types.list_events_response import ListEventsResponse
from .finetuning.types.list_finetuned_models_response import ListFinetunedModelsResponse
from .finetuning.types.list_training_step_metrics_response import ListTrainingStepMetricsResponse
from .finetuning.types.settings import Settings
from .finetuning.types.update_finetuned_model_response import UpdateFinetunedModelResponse
from pydantic import ValidationError
# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)
class RawFinetuningClient:
    """
    Raw synchronous client for the ``v1/finetuning`` endpoints.

    Each method performs one HTTP request and returns an ``HttpResponse``
    wrapping both the parsed body and the underlying httpx response. All
    endpoints share the same error contract, which is applied by
    ``_process_response``: known HTTP error statuses are translated into typed
    exceptions, everything else into a generic ``ApiError``.
    """

    # Shared status-code -> typed-error mapping. Every endpoint of this client
    # declares exactly these six error responses in the API definition.
    _STATUS_ERRORS: typing.ClassVar[typing.Dict[int, typing.Type[Exception]]] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        500: InternalServerError,
        503: ServiceUnavailableError,
    }

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    def _process_response(self, _response: typing.Any, type_: typing.Any) -> HttpResponse[typing.Any]:
        """
        Convert an httpx response into an ``HttpResponse`` or a typed error.

        On a 2xx status the JSON body is deserialized into ``type_``. On one of
        the six known error statuses the matching typed error is raised with
        the decoded JSON body. Any other status raises a generic ``ApiError``.

        Raises
        ------
        ApiError
            If the body is not valid JSON, or the status code is unrecognized.
        ParsingError
            If the JSON body fails model validation.
        """
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    typing.Any,
                    construct_type(
                        type_=type_,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            _error_class = self._STATUS_ERRORS.get(_response.status_code)
            if _error_class is not None:
                raise _error_class(
                    headers=dict(_response.headers),
                    body=typing.cast(
                        typing.Any,
                        construct_type(
                            type_=typing.Any,  # type: ignore
                            object_=_response.json(),
                        ),
                    ),
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def list_finetuned_models(
        self,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[ListFinetunedModelsResponse]:
        """
        Returns a list of fine-tuned models that the user has access to.

        Parameters
        ----------
        page_size : typing.Optional[int]
            Maximum number of results to be returned by the server. If 0, defaults to
            50.

        page_token : typing.Optional[str]
            Request a specific page of the list results.

        order_by : typing.Optional[str]
            Comma separated list of fields. For example: "created_at,name". The default
            sorting order is ascending. To specify descending order for a field, append
            " desc" to the field name. For example: "created_at desc,name".

            Supported sorting fields:
              - created_at (default)

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ListFinetunedModelsResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/finetuning/finetuned-models",
            method="GET",
            params={
                "page_size": page_size,
                "page_token": page_token,
                "order_by": order_by,
            },
            request_options=request_options,
        )
        return self._process_response(_response, ListFinetunedModelsResponse)

    def create_finetuned_model(
        self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[CreateFinetunedModelResponse]:
        """
        Creates a new fine-tuned model. The model will be trained on the dataset specified in the request body. The training process may take some time, and the model will be available once the training is complete.

        Parameters
        ----------
        request : FinetunedModel

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[CreateFinetunedModelResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/finetuning/finetuned-models",
            method="POST",
            json=convert_and_respect_annotation_metadata(object_=request, annotation=FinetunedModel, direction="write"),
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        return self._process_response(_response, CreateFinetunedModelResponse)

    def get_finetuned_model(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[GetFinetunedModelResponse]:
        """
        Retrieve a fine-tuned model by its ID.

        Parameters
        ----------
        id : str
            The fine-tuned model ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[GetFinetunedModelResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
            method="GET",
            request_options=request_options,
        )
        return self._process_response(_response, GetFinetunedModelResponse)

    def delete_finetuned_model(
        self, id: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[DeleteFinetunedModelResponse]:
        """
        Deletes a fine-tuned model. The model will be removed from the system and will no longer be available for use.
        This operation is irreversible.

        Parameters
        ----------
        id : str
            The fine-tuned model ID.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[DeleteFinetunedModelResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
            method="DELETE",
            request_options=request_options,
        )
        return self._process_response(_response, DeleteFinetunedModelResponse)

    def update_finetuned_model(
        self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[UpdateFinetunedModelResponse]:
        """
        Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body.

        Parameters
        ----------
        id : str
            FinetunedModel ID.

        name : str
            FinetunedModel name (e.g. `foobar`).

        settings : Settings
            FinetunedModel settings such as dataset, hyperparameters...

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[UpdateFinetunedModelResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
            method="PATCH",
            json={
                "name": name,
                "settings": convert_and_respect_annotation_metadata(
                    object_=settings, annotation=Settings, direction="write"
                ),
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        return self._process_response(_response, UpdateFinetunedModelResponse)

    def list_events(
        self,
        finetuned_model_id: str,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        order_by: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[ListEventsResponse]:
        """
        Returns a list of events that occurred during the life-cycle of the fine-tuned model.
        The events are ordered by creation time, with the most recent event first.
        The list can be paginated using `page_size` and `page_token` parameters.

        Parameters
        ----------
        finetuned_model_id : str
            The parent fine-tuned model ID.

        page_size : typing.Optional[int]
            Maximum number of results to be returned by the server. If 0, defaults to
            50.

        page_token : typing.Optional[str]
            Request a specific page of the list results.

        order_by : typing.Optional[str]
            Comma separated list of fields. For example: "created_at,name". The default
            sorting order is ascending. To specify descending order for a field, append
            " desc" to the field name. For example: "created_at desc,name".

            Supported sorting fields:
              - created_at (default)

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ListEventsResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/events",
            method="GET",
            params={
                "page_size": page_size,
                "page_token": page_token,
                "order_by": order_by,
            },
            request_options=request_options,
        )
        return self._process_response(_response, ListEventsResponse)

    def list_training_step_metrics(
        self,
        finetuned_model_id: str,
        *,
        page_size: typing.Optional[int] = None,
        page_token: typing.Optional[str] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[ListTrainingStepMetricsResponse]:
        """
        Returns a list of metrics measured during the training of a fine-tuned model.
        The metrics are ordered by step number, with the most recent step first.
        The list can be paginated using `page_size` and `page_token` parameters.

        Parameters
        ----------
        finetuned_model_id : str
            The parent fine-tuned model ID.

        page_size : typing.Optional[int]
            Maximum number of results to be returned by the server. If 0, defaults to
            50.

        page_token : typing.Optional[str]
            Request a specific page of the list results.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ListTrainingStepMetricsResponse]
            A successful response.
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/training-step-metrics",
            method="GET",
            params={
                "page_size": page_size,
                "page_token": page_token,
            },
            request_options=request_options,
        )
        return self._process_response(_response, ListTrainingStepMetricsResponse)
class AsyncRawFinetuningClient:
def __init__(self, *, client_wrapper: AsyncClientWrapper):
self._client_wrapper = client_wrapper
async def list_finetuned_models(
self,
*,
page_size: typing.Optional[int] = None,
page_token: typing.Optional[str] = None,
order_by: typing.Optional[str] = None,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[ListFinetunedModelsResponse]:
"""
Returns a list of fine-tuned models that the user has access to.
Parameters
----------
page_size : typing.Optional[int]
Maximum number of results to be returned by the server. If 0, defaults to
50.
page_token : typing.Optional[str]
Request a specific page of the list results.
order_by : typing.Optional[str]
Comma separated list of fields. For example: "created_at,name". The default
sorting order is ascending. To specify descending order for a field, append
" desc" to the field name. For example: "created_at desc,name".
Supported sorting fields:
- created_at (default)
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[ListFinetunedModelsResponse]
A successful response.
"""
_response = await self._client_wrapper.httpx_client.request(
"v1/finetuning/finetuned-models",
method="GET",
params={
"page_size": page_size,
"page_token": page_token,
"order_by": order_by,
},
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
ListFinetunedModelsResponse,
construct_type(
type_=ListFinetunedModelsResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def create_finetuned_model(
self, *, request: FinetunedModel, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[CreateFinetunedModelResponse]:
"""
Creates a new fine-tuned model. The model will be trained on the dataset specified in the request body. The training process may take some time, and the model will be available once the training is complete.
Parameters
----------
request : FinetunedModel
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[CreateFinetunedModelResponse]
A successful response.
"""
_response = await self._client_wrapper.httpx_client.request(
"v1/finetuning/finetuned-models",
method="POST",
json=convert_and_respect_annotation_metadata(object_=request, annotation=FinetunedModel, direction="write"),
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
CreateFinetunedModelResponse,
construct_type(
type_=CreateFinetunedModelResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def get_finetuned_model(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[GetFinetunedModelResponse]:
    """
    Retrieve a fine-tuned model by its ID.

    Parameters
    ----------
    id : str
        The fine-tuned model ID.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[GetFinetunedModelResponse]
        A successful response.
    """
    _response = await self._client_wrapper.httpx_client.request(
        f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
        method="GET",
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            # Success: parse the body into the typed response model.
            _data = typing.cast(
                GetFinetunedModelResponse,
                construct_type(
                    type_=GetFinetunedModelResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Known error statuses map one-to-one onto typed exception classes.
        _error_class = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            500: InternalServerError,
            503: ServiceUnavailableError,
        }.get(_response.status_code)
        if _error_class is not None:
            raise _error_class(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Unrecognized status: fall through to the generic ApiError below.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def delete_finetuned_model(
    self, id: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[DeleteFinetunedModelResponse]:
    """
    Deletes a fine-tuned model. The model will be removed from the system and will no longer be available for use.
    This operation is irreversible.

    Parameters
    ----------
    id : str
        The fine-tuned model ID.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[DeleteFinetunedModelResponse]
        A successful response.
    """
    _response = await self._client_wrapper.httpx_client.request(
        f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
        method="DELETE",
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            # Success: parse the body into the typed response model.
            _data = typing.cast(
                DeleteFinetunedModelResponse,
                construct_type(
                    type_=DeleteFinetunedModelResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Known error statuses map one-to-one onto typed exception classes.
        _error_class = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            500: InternalServerError,
            503: ServiceUnavailableError,
        }.get(_response.status_code)
        if _error_class is not None:
            raise _error_class(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Unrecognized status: fall through to the generic ApiError below.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def update_finetuned_model(
    self, id: str, *, name: str, settings: Settings, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[UpdateFinetunedModelResponse]:
    """
    Updates the fine-tuned model with the given ID. The model will be updated with the new settings and name provided in the request body.

    Parameters
    ----------
    id : str
        FinetunedModel ID.

    name : str
        FinetunedModel name (e.g. `foobar`).

    settings : Settings
        FinetunedModel settings such as dataset, hyperparameters...

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[UpdateFinetunedModelResponse]
        A successful response.
    """
    _response = await self._client_wrapper.httpx_client.request(
        f"v1/finetuning/finetuned-models/{jsonable_encoder(id)}",
        method="PATCH",
        json={
            "name": name,
            "settings": convert_and_respect_annotation_metadata(
                object_=settings, annotation=Settings, direction="write"
            ),
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    try:
        if 200 <= _response.status_code < 300:
            # Success: parse the body into the typed response model.
            _data = typing.cast(
                UpdateFinetunedModelResponse,
                construct_type(
                    type_=UpdateFinetunedModelResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Known error statuses map one-to-one onto typed exception classes.
        _error_class = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            500: InternalServerError,
            503: ServiceUnavailableError,
        }.get(_response.status_code)
        if _error_class is not None:
            raise _error_class(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Unrecognized status: fall through to the generic ApiError below.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def list_events(
    self,
    finetuned_model_id: str,
    *,
    page_size: typing.Optional[int] = None,
    page_token: typing.Optional[str] = None,
    order_by: typing.Optional[str] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[ListEventsResponse]:
    """
    Returns a list of events that occurred during the life-cycle of the fine-tuned model.
    The events are ordered by creation time, with the most recent event first.
    The list can be paginated using `page_size` and `page_token` parameters.

    Parameters
    ----------
    finetuned_model_id : str
        The parent fine-tuned model ID.

    page_size : typing.Optional[int]
        Maximum number of results to be returned by the server. If 0, defaults to
        50.

    page_token : typing.Optional[str]
        Request a specific page of the list results.

    order_by : typing.Optional[str]
        Comma separated list of fields. For example: "created_at,name". The default
        sorting order is ascending. To specify descending order for a field, append
        " desc" to the field name. For example: "created_at desc,name".

        Supported sorting fields:
        - created_at (default)

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[ListEventsResponse]
        A successful response.
    """
    _response = await self._client_wrapper.httpx_client.request(
        f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/events",
        method="GET",
        params={
            "page_size": page_size,
            "page_token": page_token,
            "order_by": order_by,
        },
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            # Success: parse the body into the typed response model.
            _data = typing.cast(
                ListEventsResponse,
                construct_type(
                    type_=ListEventsResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Known error statuses map one-to-one onto typed exception classes.
        _error_class = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            500: InternalServerError,
            503: ServiceUnavailableError,
        }.get(_response.status_code)
        if _error_class is not None:
            raise _error_class(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Unrecognized status: fall through to the generic ApiError below.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def list_training_step_metrics(
    self,
    finetuned_model_id: str,
    *,
    page_size: typing.Optional[int] = None,
    page_token: typing.Optional[str] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[ListTrainingStepMetricsResponse]:
    """
    Returns a list of metrics measured during the training of a fine-tuned model.
    The metrics are ordered by step number, with the most recent step first.
    The list can be paginated using `page_size` and `page_token` parameters.

    Parameters
    ----------
    finetuned_model_id : str
        The parent fine-tuned model ID.

    page_size : typing.Optional[int]
        Maximum number of results to be returned by the server. If 0, defaults to
        50.

    page_token : typing.Optional[str]
        Request a specific page of the list results.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[ListTrainingStepMetricsResponse]
        A successful response.
    """
    _response = await self._client_wrapper.httpx_client.request(
        f"v1/finetuning/finetuned-models/{jsonable_encoder(finetuned_model_id)}/training-step-metrics",
        method="GET",
        params={
            "page_size": page_size,
            "page_token": page_token,
        },
        request_options=request_options,
    )
    try:
        if 200 <= _response.status_code < 300:
            # Success: parse the body into the typed response model.
            _data = typing.cast(
                ListTrainingStepMetricsResponse,
                construct_type(
                    type_=ListTrainingStepMetricsResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Known error statuses map one-to-one onto typed exception classes.
        _error_class = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            500: InternalServerError,
            503: ServiceUnavailableError,
        }.get(_response.status_code)
        if _error_class is not None:
            raise _error_class(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Unrecognized status: fall through to the generic ApiError below.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/manually_maintained/__init__.py
================================================
# This module ensures overrides are applied early in the import process
# Import overrides to trigger backwards compatibility patches
from .. import overrides # noqa: F401
================================================
FILE: src/cohere/manually_maintained/cache.py
================================================
import typing
import time
class CacheMixin:
    """A simple in-memory cache with TTL support.

    Entries are stored as ``key -> (expiry_timestamp | None, value)``, where
    ``None`` means the entry never expires. Currently used to cache tokenizers.
    """

    # Shared at class level on purpose: all instances (and subclasses) see the
    # same cache, since the cached objects are process-wide resources.
    _cache: typing.Dict[str, typing.Tuple[typing.Optional[float], typing.Any]] = dict()

    def _cache_get(self, key: str) -> typing.Any:
        """Return the cached value for ``key``, or ``None`` if absent or expired."""
        entry = self._cache.get(key)
        if entry is None:
            return None
        expiry_timestamp, value = entry
        if expiry_timestamp is None or expiry_timestamp > time.time():
            return value
        # Expired: evict lazily. pop() (rather than del) tolerates a concurrent
        # eviction of the same key, so this cannot raise KeyError under races.
        self._cache.pop(key, None)
        return None

    def _cache_set(self, key: str, value: typing.Any, ttl: typing.Optional[int] = 60 * 60) -> None:
        """Store ``value`` under ``key`` for ``ttl`` seconds (``None`` = no expiry).

        Note: the annotation is Optional[int] because the implementation has
        always accepted ``ttl=None`` for non-expiring entries; the previous
        ``int`` annotation contradicted the ``ttl is not None`` check below.
        """
        expiry_timestamp = None if ttl is None else time.time() + ttl
        self._cache[key] = (expiry_timestamp, value)
================================================
FILE: src/cohere/manually_maintained/cohere_aws/__init__.py
================================================
from .client import Client
from .error import CohereError
from .mode import Mode
================================================
FILE: src/cohere/manually_maintained/cohere_aws/chat.py
================================================
from .response import CohereObject
from .error import CohereError
from .mode import Mode
from typing import List, Optional, Generator, Dict, Any, Union
from enum import Enum
import json
# Tools
class ToolParameterDefinitionsValue(CohereObject, dict):
    """Definition of a single tool parameter (type, description, required flag).

    Subclasses dict and aliases __dict__ to itself so the object serializes
    as a plain mapping while still supporting attribute access.
    """

    def __init__(self, type: str, description: str, required: Optional[bool] = None, **kwargs) -> None:
        super().__init__(**kwargs)
        # Attribute assignments below also populate the dict representation.
        self.__dict__ = self
        self.type = type
        self.description = description
        if required is not None:
            self.required = required
class Tool(CohereObject, dict):
    """A callable tool the model may invoke, described by name and parameters.

    Subclasses dict and aliases __dict__ to itself so the object serializes
    as a plain mapping while still supporting attribute access.
    """

    def __init__(
        self,
        name: str,
        description: str,
        parameter_definitions: Optional[Dict[str, ToolParameterDefinitionsValue]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Attribute assignments below also populate the dict representation.
        self.__dict__ = self
        self.name = name
        self.description = description
        if parameter_definitions is not None:
            self.parameter_definitions = parameter_definitions
class ToolCall(CohereObject, dict):
    """A tool invocation requested by the model (name + parameter values).

    Subclasses dict and aliases __dict__ to itself so the object serializes
    as a plain mapping while still supporting attribute access.
    """

    def __init__(self, name: str, parameters: Dict[str, Any], generation_id: str, **kwargs) -> None:
        super().__init__(**kwargs)
        # Attribute assignments below also populate the dict representation.
        self.__dict__ = self
        self.name = name
        self.parameters = parameters
        self.generation_id = generation_id

    @classmethod
    def from_dict(cls, tool_call_res: Dict[str, Any]) -> "ToolCall":
        """Build a ToolCall from one raw API tool-call payload."""
        return cls(
            name=tool_call_res.get("name"),
            parameters=tool_call_res.get("parameters"),
            generation_id=tool_call_res.get("generation_id"),
        )

    @classmethod
    def from_list(cls, tool_calls_res: Optional[List[Dict[str, Any]]]) -> Optional[List["ToolCall"]]:
        """Convert a raw list of tool-call payloads; None/non-list input yields None."""
        if isinstance(tool_calls_res, list):
            return [ToolCall.from_dict(entry) for entry in tool_calls_res]
        return None
# Chat
class Chat(CohereObject):
    """A complete (non-streamed) chat response."""

    def __init__(
        self,
        response_id: str,
        generation_id: str,
        text: str,
        chat_history: Optional[List[Dict[str, Any]]] = None,
        preamble: Optional[str] = None,
        finish_reason: Optional[str] = None,
        token_count: Optional[Dict[str, int]] = None,
        tool_calls: Optional[List[ToolCall]] = None,
        citations: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, Any]]] = None,
        search_results: Optional[List[Dict[str, Any]]] = None,
        search_queries: Optional[List[Dict[str, Any]]] = None,
        is_search_required: Optional[bool] = None,
    ) -> None:
        # Core generation fields.
        self.response_id = response_id
        self.generation_id = generation_id
        self.text = text
        self.finish_reason = finish_reason
        self.token_count = token_count
        # Conversation context.
        self.chat_history = chat_history
        self.preamble = preamble
        # RAG / tool-use fields (all optional).
        self.tool_calls = tool_calls
        self.citations = citations
        self.documents = documents
        self.search_results = search_results
        self.search_queries = search_queries
        self.is_search_required = is_search_required

    @classmethod
    def from_dict(cls, response: Dict[str, Any]) -> "Chat":
        """Build a Chat from a raw response payload; response_id is required."""
        return cls(
            response_id=response["response_id"],
            generation_id=response.get("generation_id"),  # optional
            text=response.get("text"),
            chat_history=response.get("chat_history"),  # optional
            preamble=response.get("preamble"),  # optional
            finish_reason=response.get("finish_reason"),
            token_count=response.get("token_count"),
            tool_calls=ToolCall.from_list(response.get("tool_calls")),  # optional
            citations=response.get("citations"),  # optional
            documents=response.get("documents"),  # optional
            search_results=response.get("search_results"),  # optional
            search_queries=response.get("search_queries"),  # optional
            is_search_required=response.get("is_search_required"),  # optional
        )
# ---------------|
# Streaming event |
# ---------------|
class StreamEvent(str, Enum):
    """Event-type discriminator for streamed chat payloads.

    The values match the `event_type` strings carried in each streaming
    JSON object (dispatched on in StreamingChat._make_response_item).
    """

    STREAM_START = "stream-start"
    SEARCH_QUERIES_GENERATION = "search-queries-generation"
    SEARCH_RESULTS = "search-results"
    TEXT_GENERATION = "text-generation"
    TOOL_CALLS_GENERATION = "tool-calls-generation"
    CITATION_GENERATION = "citation-generation"
    STREAM_END = "stream-end"
class StreamResponse(CohereObject):
    """Base class for every event produced while iterating a StreamingChat."""

    def __init__(self, is_finished: bool, event_type: Union[StreamEvent, str], index: Optional[int], **kwargs) -> None:
        super().__init__(**kwargs)
        self.event_type = event_type
        self.is_finished = is_finished
        self.index = index
class StreamStart(StreamResponse):
    """First event of a stream; carries the generation and conversation IDs."""

    def __init__(self, generation_id: str, conversation_id: Optional[str], **kwargs) -> None:
        super().__init__(**kwargs)
        self.conversation_id = conversation_id
        self.generation_id = generation_id
class StreamTextGeneration(StreamResponse):
    """A chunk of generated text emitted mid-stream."""

    def __init__(self, text: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.text = text
class StreamCitationGeneration(StreamResponse):
    """An event carrying citations generated for the response so far."""

    def __init__(self, citations: Optional[List[Dict[str, Any]]], **kwargs) -> None:
        super().__init__(**kwargs)
        self.citations = citations
class StreamQueryGeneration(StreamResponse):
    """An event carrying the search queries the model decided to run."""

    def __init__(self, search_queries: Optional[List[Dict[str, Any]]], **kwargs) -> None:
        super().__init__(**kwargs)
        self.search_queries = search_queries
class StreamSearchResults(StreamResponse):
    """An event carrying search results and the documents they produced."""

    def __init__(
        self,
        search_results: Optional[List[Dict[str, Any]]],
        documents: Optional[List[Dict[str, Any]]],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.documents = documents
        self.search_results = search_results
class StreamEnd(StreamResponse):
    """Terminal event of a stream; carries the finish reason."""

    def __init__(self, finish_reason: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.finish_reason = finish_reason
class ChatToolCallsGenerationEvent(StreamResponse):
    """An event carrying tool calls the model generated mid-stream."""

    def __init__(self, tool_calls: Optional[List[ToolCall]], **kwargs) -> None:
        super().__init__(**kwargs)
        self.tool_calls = tool_calls
class StreamingChat(CohereObject):
    """Iterator over a streamed chat response from SageMaker or Bedrock.

    Buffers raw payload bytes until a complete JSON object can be decoded,
    converts each decoded object into a typed StreamResponse event, and —
    when the stream-end event arrives — copies the final response fields
    (text, citations, tool_calls, ...) onto this object for post-stream access.
    """

    def __init__(self, stream_response, mode):
        self.stream_response = stream_response
        # Final-response fields, populated when the stream-end event is seen.
        self.text = None
        self.response_id = None
        self.generation_id = None
        # Fix: conversation_id was previously never initialized here even
        # though _make_response_item assigns it and every sibling field is
        # pre-set to None; reading it before a stream-start event raised
        # AttributeError.
        self.conversation_id = None
        self.preamble = None
        self.prompt = None
        self.chat_history = None
        self.finish_reason = None
        self.token_count = None
        self.is_search_required = None
        self.citations = None
        self.documents = None
        self.search_results = None
        self.search_queries = None
        self.tool_calls = None
        # Accumulates raw bytes across payloads until a full JSON object parses.
        self.bytes = bytearray()
        # The two backends wrap stream chunks under different keys.
        if mode == Mode.SAGEMAKER:
            self.payload_key = "PayloadPart"
            self.bytes_key = "Bytes"
        elif mode == Mode.BEDROCK:
            self.payload_key = "chunk"
            self.bytes_key = "bytes"

    def _make_response_item(self, index, streaming_item) -> Any:
        """Convert one decoded streaming JSON object into a typed event.

        Returns None for unknown event types and for a stream-end event that
        carries no final response payload.
        """
        event_type = streaming_item.get("event_type")
        if event_type == StreamEvent.STREAM_START:
            self.conversation_id = streaming_item.get("conversation_id")
            self.generation_id = streaming_item.get("generation_id")
            return StreamStart(
                conversation_id=self.conversation_id,
                generation_id=self.generation_id,
                is_finished=False,
                event_type=event_type,
                index=index,
            )
        elif event_type == StreamEvent.SEARCH_QUERIES_GENERATION:
            search_queries = streaming_item.get("search_queries")
            return StreamQueryGeneration(
                search_queries=search_queries, is_finished=False, event_type=event_type, index=index
            )
        elif event_type == StreamEvent.SEARCH_RESULTS:
            search_results = streaming_item.get("search_results")
            documents = streaming_item.get("documents")
            return StreamSearchResults(
                search_results=search_results,
                documents=documents,
                is_finished=False,
                event_type=event_type,
                index=index,
            )
        elif event_type == StreamEvent.TEXT_GENERATION:
            text = streaming_item.get("text")
            return StreamTextGeneration(text=text, is_finished=False, event_type=event_type, index=index)
        elif event_type == StreamEvent.CITATION_GENERATION:
            citations = streaming_item.get("citations")
            return StreamCitationGeneration(citations=citations, is_finished=False, event_type=event_type, index=index)
        elif event_type == StreamEvent.TOOL_CALLS_GENERATION:
            tool_calls = ToolCall.from_list(streaming_item.get("tool_calls"))
            return ChatToolCallsGenerationEvent(
                tool_calls=tool_calls, is_finished=False, event_type=event_type, index=index
            )
        elif event_type == StreamEvent.STREAM_END:
            response = streaming_item.get("response")
            finish_reason = streaming_item.get("finish_reason")
            self.finish_reason = finish_reason
            if response is None:
                return None
            # Copy the final response fields onto self for post-stream access.
            self.response_id = response.get("response_id")
            self.conversation_id = response.get("conversation_id")
            self.text = response.get("text")
            self.generation_id = response.get("generation_id")
            self.preamble = response.get("preamble")
            self.prompt = response.get("prompt")
            self.chat_history = response.get("chat_history")
            self.token_count = response.get("token_count")
            self.is_search_required = response.get("is_search_required")  # optional
            self.citations = response.get("citations")  # optional
            self.documents = response.get("documents")  # optional
            self.search_results = response.get("search_results")  # optional
            self.search_queries = response.get("search_queries")  # optional
            self.tool_calls = ToolCall.from_list(response.get("tool_calls"))  # optional
            return StreamEnd(finish_reason=finish_reason, is_finished=True, event_type=event_type, index=index)
        return None

    def __iter__(self) -> Generator[StreamResponse, None, None]:
        index = 0
        for payload in self.stream_response:
            self.bytes.extend(payload[self.payload_key][self.bytes_key])
            try:
                item = self._make_response_item(index, json.loads(self.bytes))
            except json.decoder.JSONDecodeError:
                # payload contained only a partial JSON object; keep buffering
                continue
            self.bytes = bytearray()
            if item is not None:
                index += 1
                yield item
================================================
FILE: src/cohere/manually_maintained/cohere_aws/classification.py
================================================
from .response import CohereObject
from typing import Any, Dict, Iterator, List, Literal, Union
# Version 1 of classification-finetuning returns a bare label (str/int), or a
# list of labels for multi-label models.
Prediction = Union[str, int, List[str], List[int]]
# Version 2 returns a dict that also carries the original text and the labels'
# confidence scores alongside the prediction.
ClassificationDict = Dict[Literal["prediction", "confidence", "text"], Any]
class Classification(CohereObject):
    """A single classification result.

    Wraps either the legacy bare-prediction format (version 1 of
    classification-finetuning) or the dict format (version 2), which also
    carries the original text and the labels' confidence scores.
    """

    def __init__(self, classification: Union[Prediction, ClassificationDict]) -> None:
        self.classification = classification

    def is_multilabel(self) -> bool:
        """True when the prediction is a list of labels (multi-label model)."""
        if isinstance(self.classification, (int, str)):
            return False
        if isinstance(self.classification, list):
            return True
        return isinstance(self.classification["prediction"], list)

    @property
    def prediction(self) -> Prediction:
        """The predicted label(s), available in both response formats."""
        if isinstance(self.classification, (list, int, str)):
            return self.classification
        return self.classification["prediction"]

    @property
    def confidence(self) -> List[float]:
        """Per-label confidence scores; only present in the version 2 format."""
        if isinstance(self.classification, (list, int, str)):
            raise ValueError(
                "Confidence scores are not available for version prior to 2.0 of Cohere Classification Finetuning AWS package"
            )
        return self.classification["confidence"]

    @property
    def text(self) -> str:
        """The original input text; only present in the version 2 format."""
        if isinstance(self.classification, (list, int, str)):
            raise ValueError(
                "Original text is not available for version prior to 2.0 of Cohere Classification Finetuning AWS package"
            )
        return self.classification["text"]
class Classifications(CohereObject):
    """An ordered collection of Classification results.

    All elements must agree on single-label vs multi-label; this invariant is
    checked on construction (via `assert`, kept for backward compatibility —
    note it is stripped under `python -O`).
    """

    def __init__(self, classifications: List[Classification]) -> None:
        self.classifications = classifications
        if len(self.classifications) > 0:
            # Generator expression (instead of a materialized list) lets all()
            # short-circuit on the first mismatched element.
            assert all(
                c.is_multilabel() == self.is_multilabel() for c in self.classifications
            ), "All classifications must be of the same type (single-label or multi-label)"

    def __iter__(self) -> Iterator:
        return iter(self.classifications)

    def __len__(self) -> int:
        return len(self.classifications)

    def is_multilabel(self) -> bool:
        """True when the (non-empty) collection holds multi-label results."""
        return len(self.classifications) > 0 and self.classifications[0].is_multilabel()
================================================
FILE: src/cohere/manually_maintained/cohere_aws/client.py
================================================
import json
import os
import tarfile
import tempfile
import time
from typing import Any, Dict, List, Optional, Union
from .classification import Classification, Classifications
from .embeddings import Embeddings
from .error import CohereError
from .generation import Generations, StreamingGenerations
from .chat import Chat, StreamingChat
from .rerank import Reranking
from .summary import Summary
from .mode import Mode
import typing
from ..lazy_aws_deps import lazy_boto3, lazy_botocore, lazy_sagemaker
class Client:
def __init__(
    self,
    aws_region: typing.Optional[str] = None,
    mode: Mode = Mode.SAGEMAKER,
):
    """Create a client for the SageMaker or Bedrock runtime.

    By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
    `aws configure set region us-west-2` or override it with `region_name` parameter.

    Args:
        aws_region (str, optional): AWS region name; when omitted, boto3's
            normal resolution (env var / AWS CLI config) applies.
        mode (Mode): Which AWS service to target (SAGEMAKER or BEDROCK).

    Raises:
        CohereError: If `mode` is not a supported Mode value.
    """
    self.mode = mode
    # Only backfill the env var when a region was actually provided:
    # `os.environ[...] = None` raises TypeError, which the previous code hit
    # whenever aws_region was omitted and AWS_DEFAULT_REGION was unset.
    if aws_region is not None and os.environ.get('AWS_DEFAULT_REGION') is None:
        os.environ['AWS_DEFAULT_REGION'] = aws_region
    if self.mode == Mode.SAGEMAKER:
        self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
        self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
        self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
    elif self.mode == Mode.BEDROCK:
        self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
        self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
        self._sess = None
    else:
        # Fail fast instead of surfacing an AttributeError later when
        # self._client is first used.
        raise CohereError(f"Unsupported mode: {mode}")
    self._endpoint_name = None
def _require_sagemaker(self) -> None:
    """Raise CohereError unless this client was created in SageMaker mode."""
    if self.mode == Mode.SAGEMAKER:
        return
    raise CohereError("This method is only supported in SageMaker mode.")
def _does_endpoint_exist(self, endpoint_name: str) -> bool:
    """Return True when a SageMaker endpoint with this name can be described."""
    try:
        self._service_client.describe_endpoint(EndpointName=endpoint_name)
        return True
    except lazy_botocore().ClientError:
        # describe_endpoint raises ClientError for unknown endpoints.
        return False
def connect_to_endpoint(self, endpoint_name: str) -> None:
    """Connects to an existing SageMaker endpoint.

    Args:
        endpoint_name (str): The name of the endpoint.

    Raises:
        CohereError: Connection to the endpoint failed.
    """
    self._require_sagemaker()
    if self._does_endpoint_exist(endpoint_name):
        self._endpoint_name = endpoint_name
    else:
        raise CohereError(f"Endpoint {endpoint_name} does not exist.")
def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:
    """
    Compress an S3 folder which contains one or several fine-tuned models to a tar file.
    If the S3 folder contains only one fine-tuned model, it simply returns the path to that model.
    If the S3 folder contains several fine-tuned models, it downloads all models and aggregates them into a single
    tar.gz file.

    Args:
        s3_models_dir (str): S3 URI pointing to a folder

    Returns:
        str: S3 URI pointing to the `models.tar.gz` file

    Raises:
        CohereError: If no fine-tuned models are found in `s3_models_dir`.
    """
    # Normalize to exactly one trailing slash so the root-level prefix
    # comparison below works.
    s3_models_dir = s3_models_dir.rstrip("/") + "/"
    # Links of all fine-tuned models in s3_models_dir. Their format should be .tar.gz
    s3_tar_models = [
        s3_path
        for s3_path in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
        if (
            s3_path.endswith(".tar.gz")  # only .tar.gz files
            and (s3_path.split("/")[-1] != "models.tar.gz")  # exclude the .tar.gz file we are creating
            and (s3_path.rsplit("/", 1)[0] == s3_models_dir[:-1])  # only files at the root of s3_models_dir
        )
    ]
    if len(s3_tar_models) == 0:
        raise CohereError(f"No fine-tuned models found in {s3_models_dir}")
    elif len(s3_tar_models) == 1:
        # Single model: nothing to aggregate, return its S3 URI directly.
        print(f"Found one fine-tuned model: {s3_tar_models[0]}")
        return s3_tar_models[0]
    # More than one fine-tuned model found, need to aggregate them into a single .tar.gz file
    with tempfile.TemporaryDirectory() as tmpdir:
        local_tar_models_dir = os.path.join(tmpdir, "tar")
        local_models_dir = os.path.join(tmpdir, "models")
        # Download and extract all fine-tuned models
        for s3_tar_model in s3_tar_models:
            print(f"Adding fine-tuned model: {s3_tar_model}")
            lazy_sagemaker().s3.S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess)
            # NOTE(review): extractall without a `filter` trusts the archive
            # contents; presumably acceptable because these are the caller's
            # own S3 artifacts — confirm before reusing on untrusted input.
            with tarfile.open(os.path.join(local_tar_models_dir, s3_tar_model.split("/")[-1])) as tar:
                tar.extractall(local_models_dir)
        # Compress local_models_dir to a tar.gz file
        model_tar = os.path.join(tmpdir, "models.tar.gz")
        with tarfile.open(model_tar, "w:gz") as tar:
            tar.add(local_models_dir, arcname=".")
        # Upload the new tarfile containing all models to s3
        # Very important to remove the trailing slash from s3_models_dir otherwise it just doesn't upload
        model_tar_s3 = lazy_sagemaker().s3.S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess)
        # sanity check
        assert s3_models_dir + "models.tar.gz" in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
    return model_tar_s3
def create_endpoint(
    self,
    arn: str,
    endpoint_name: str,
    s3_models_dir: Optional[str] = None,
    instance_type: str = "ml.g4dn.xlarge",
    n_instances: int = 1,
    recreate: bool = False,
    role: Optional[str] = None,
) -> None:
    """Creates and deploys a SageMaker endpoint.

    Args:
        arn (str): The product ARN. Refers to a ready-to-use model (model package) or a fine-tuned model
            (algorithm).
        endpoint_name (str): The name of the endpoint.
        s3_models_dir (str, optional): S3 URI pointing to the folder containing fine-tuned models. Defaults to None.
        instance_type (str, optional): The EC2 instance type to deploy the endpoint to. Defaults to "ml.g4dn.xlarge".
        n_instances (int, optional): Number of endpoint instances. Defaults to 1.
        recreate (bool, optional): Force re-creation of endpoint if it already exists. Defaults to False.
        role (str, optional): The IAM role to use for the endpoint. If not provided, sagemaker.get_execution_role()
            will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
            out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.

    Raises:
        CohereError: If the endpoint already exists and recreate=False, or if endpoint
            creation does not reach the in-service state.
    """
    self._require_sagemaker()
    # First, check if endpoint already exists
    if self._does_endpoint_exist(endpoint_name):
        if recreate:
            # connect first so delete_endpoint() knows which endpoint to tear down
            self.connect_to_endpoint(endpoint_name)
            self.delete_endpoint()
        else:
            raise CohereError(f"Endpoint {endpoint_name} already exists and recreate={recreate}.")

    kwargs = {}
    model_data = None
    validation_params = dict()
    useBoto = False
    if s3_models_dir is not None:
        # If s3_models_dir is given, we assume to have custom fine-tuned models -> Algorithm
        kwargs["algorithm_arn"] = arn
        # aggregate every fine-tuned model under s3_models_dir into one models.tar.gz
        model_data = self._s3_models_dir_to_tarfile(s3_models_dir)
    else:
        # If no s3_models_dir is given, we assume to use a pre-trained model -> ModelPackage
        kwargs["model_package_arn"] = arn

        # For now only non-finetuned models can use these timeouts
        validation_params = dict(
            model_data_download_timeout=2400,
            container_startup_health_check_timeout=2400
        )
        useBoto = True

    # Out of precaution, check if there is an endpoint config and delete it if that's the case
    # Otherwise it might block deployment
    try:
        self._service_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
    except lazy_botocore().ClientError:
        pass

    try:
        self._service_client.delete_model(ModelName=endpoint_name)
    except lazy_botocore().ClientError:
        pass

    if role is None:
        if useBoto:
            # boto path: build the role ARN explicitly from the account id
            accountID = lazy_sagemaker().account_id()
            role = f"arn:aws:iam::{accountID}:role/ServiceRoleSagemaker"
        else:
            try:
                role = lazy_sagemaker().get_execution_role()
            except ValueError:
                print("Using default role: 'ServiceRoleSagemaker'.")
                role = "ServiceRoleSagemaker"

    # deploy fine-tuned model using sagemaker SDK
    if s3_models_dir is not None:
        model = lazy_sagemaker().ModelPackage(
            role=role,
            model_data=model_data,
            sagemaker_session=self._sess,  # makes sure the right region is used
            **kwargs
        )
        try:
            model.deploy(
                n_instances,
                instance_type,
                endpoint_name=endpoint_name,
                **validation_params
            )
        except lazy_botocore().ParamValidationError:
            # For at least some versions of python 3.6, SageMaker SDK does not support the validation_params
            model.deploy(n_instances, instance_type, endpoint_name=endpoint_name)
    else:
        # deploy pre-trained model using boto to add InferenceAmiVersion
        self._service_client.create_model(
            ModelName=endpoint_name,
            ExecutionRoleArn=role,
            EnableNetworkIsolation=True,
            PrimaryContainer={
                'ModelPackageName': arn,
            },
        )
        self._service_client.create_endpoint_config(
            EndpointConfigName=endpoint_name,
            ProductionVariants=[
                {
                    'VariantName': 'AllTraffic',
                    'ModelName': endpoint_name,
                    'InstanceType': instance_type,
                    'InitialInstanceCount': n_instances,
                    'InferenceAmiVersion': 'al2-ami-sagemaker-inference-gpu-2'
                },
            ],
        )
        self._service_client.create_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint_name,
        )
        # block until the endpoint is usable (up to 80 * 30s = 40 minutes)
        waiter = self._service_client.get_waiter('endpoint_in_service')
        try:
            print(f"Waiting for endpoint {endpoint_name} to be in service...")
            waiter.wait(
                EndpointName=endpoint_name,
                WaiterConfig={
                    'Delay': 30,
                    'MaxAttempts': 80
                }
            )
        except Exception as e:
            raise CohereError(f"Failed to create endpoint: {e}")
    self.connect_to_endpoint(endpoint_name)
def chat(
    self,
    message: str,
    stream: Optional[bool] = False,
    preamble: Optional[str] = None,
    chat_history: Optional[List[Dict[str, Any]]] = None,
    # should only be passed for stacked finetune deployment
    model: Optional[str] = None,
    # should only be passed for Bedrock mode; ignored otherwise
    model_id: Optional[str] = None,
    temperature: Optional[float] = None,
    p: Optional[float] = None,
    k: Optional[float] = None,
    max_tokens: Optional[int] = None,
    search_queries_only: Optional[bool] = None,
    documents: Optional[List[Dict[str, Any]]] = None,
    prompt_truncation: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_results: Optional[List[Dict[str, Any]]] = None,
    raw_prompting: Optional[bool] = False,
    return_prompt: Optional[bool] = False,
    variant: Optional[str] = None,
) -> Union[Chat, StreamingChat]:
    """Returns a Chat object with the query reply.

    Args:
        message (str): The message to send to the chatbot.
        stream (bool): Return streaming tokens.
        preamble (str): (Optional) A string to override the preamble.
        chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.
        model (str): (Optional) The model to use for generating the response. Should only be passed for stacked finetune deployment.
        model_id (str): (Optional) The model to use for generating the response. Should only be passed for Bedrock mode; ignored otherwise.
        temperature (float): (Optional) The temperature to use for the response. The higher the temperature, the more random the response.
        p (float): (Optional) The nucleus sampling probability.
        k (float): (Optional) The top-k sampling probability.
        max_tokens (int): (Optional) The max tokens generated for the next reply.
        search_queries_only (bool): (Optional) When true, the response will only contain a list of generated `search_queries`, no reply from the model to the user's message will be generated.
        documents (List[Dict[str, str]]): (Optional) Documents to use to generate grounded response with citations. Example:
            documents=[
                {
                    "id": "national_geographic_everest",
                    "title": "Height of Mount Everest",
                    "snippet": "The height of Mount Everest is 29,035 feet",
                    "url": "https://education.nationalgeographic.org/resource/mount-everest/",
                },
                {
                    "id": "national_geographic_mariana",
                    "title": "Depth of the Mariana Trench",
                    "snippet": "The depth of the Mariana Trench is 36,070 feet",
                    "url": "https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth",
                },
            ],
        prompt_truncation (str) (Optional): Defaults to `OFF`. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be raised.
        tools (List[Dict[str, Any]]): (Optional) Tool definitions; forwarded verbatim in the request body.
        tool_results (List[Dict[str, Any]]): (Optional) Results of earlier tool invocations; forwarded verbatim in the request body.
        raw_prompting (bool): (Optional) Forwarded verbatim in the request body.
        return_prompt (bool): (Optional) Forwarded verbatim in the request body.
        variant (str): (Optional) SageMaker production variant to target; only used in SageMaker mode.

    Returns:
        a Chat object if stream=False, or a StreamingChat object if stream=True

    Raises:
        CohereError: if no endpoint is connected in SageMaker mode, or the client
            mode is neither SageMaker nor Bedrock.

    Examples:
        A simple chat message:
            >>> res = co.chat(message="Hey! How are you doing today?")
            >>> print(res.text)
        Streaming chat:
            >>> res = co.chat(
            >>>     message="Hey! How are you doing today?",
            >>>     stream=True)
            >>> for token in res:
            >>>     print(token)
        Stateless chat with chat history:
            >>> res = co.chat(
            >>>     message="Tell me a joke!",
            >>>     chat_history=[
            >>>         {'role': 'User', 'message': 'Hey! How are you doing today?'},
            >>>         {'role': 'Chatbot', 'message': 'I am doing great! How can I help you?'},
            >>>     ])
            >>> print(res.text)
        Chat message with documents to use to generate the response:
            >>> res = co.chat(
            >>>     "How deep in the Mariana Trench",
            >>>     documents=[
            >>>         {
            >>>             "id": "national_geographic_everest",
            >>>             "title": "Height of Mount Everest",
            >>>             "snippet": "The height of Mount Everest is 29,035 feet",
            >>>             "url": "https://education.nationalgeographic.org/resource/mount-everest/",
            >>>         },
            >>>         {
            >>>             "id": "national_geographic_mariana",
            >>>             "title": "Depth of the Mariana Trench",
            >>>             "snippet": "The depth of the Mariana Trench is 36,070 feet",
            >>>             "url": "https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth",
            >>>         },
            >>>     ])
            >>> print(res.text)
            >>> print(res.citations)
            >>> print(res.documents)
        Generate search queries for fetching documents to use in chat:
            >>> res = co.chat(
            >>>     "What is the height of Mount Everest?",
            >>>     search_queries_only=True)
            >>> if res.is_search_required:
            >>>     print(res.search_queries)
    """
    if self.mode == Mode.SAGEMAKER and self._endpoint_name is None:
        raise CohereError("No endpoint connected. "
                          "Run connect_to_endpoint() first.")

    json_params = {
        "model": model,
        "message": message,
        "chat_history": chat_history,
        "preamble": preamble,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": stream,
        "p": p,
        "k": k,
        "tools": tools,
        "tool_results": tool_results,
        "search_queries_only": search_queries_only,
        "documents": documents,
        "raw_prompting": raw_prompting,
        "return_prompt": return_prompt,
        "prompt_truncation": prompt_truncation
    }
    # Drop unset fields so the request body only carries explicit values.
    for key, value in list(json_params.items()):
        if value is None:
            del json_params[key]

    if self.mode == Mode.SAGEMAKER:
        return self._sagemaker_chat(json_params, variant)
    elif self.mode == Mode.BEDROCK:
        return self._bedrock_chat(json_params, model_id)
    else:
        raise CohereError("Unsupported mode")
def _sagemaker_chat(self, json_params: Dict[str, Any], variant: Optional[str]) -> Union[Chat, StreamingChat]:
    """Invoke the connected SageMaker endpoint with a chat request body.

    Args:
        json_params: Request body; None-valued fields were already stripped by chat().
        variant: Optional SageMaker production variant to target.

    Returns:
        StreamingChat when streaming was requested, otherwise Chat.

    Raises:
        CohereError: on any invocation failure.
    """
    json_body = json.dumps(json_params)
    params = {
        'EndpointName': self._endpoint_name,
        'ContentType': 'application/json',
        'Body': json_body,
    }
    if variant:
        params['TargetVariant'] = variant

    try:
        # .get(): chat() deletes None-valued keys, so "stream" may be absent
        # (e.g. when the caller explicitly passed stream=None).
        if json_params.get('stream'):
            result = self._client.invoke_endpoint_with_response_stream(
                **params)
            return StreamingChat(result['Body'], self.mode)
        else:
            result = self._client.invoke_endpoint(**params)
            return Chat.from_dict(json.loads(result['Body'].read().decode()))
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))
def _bedrock_chat(self, json_params: Dict[str, Any], model_id: Optional[str]) -> Union[Chat, StreamingChat]:
    """Invoke a Bedrock model with a chat request body.

    Args:
        json_params: Request body; None-valued fields were already stripped by chat().
        model_id: Bedrock model identifier (required).

    Returns:
        StreamingChat when streaming was requested, otherwise Chat.

    Raises:
        CohereError: if model_id is missing or on any invocation failure.
    """
    if not model_id:
        raise CohereError("must supply model_id arg when calling bedrock")

    # Bedrock does not expect the stream key to be present in the body;
    # stream mode is selected by calling invoke_model_with_response_stream.
    # pop() with a default also fixes the KeyError the old lookup raised
    # when "stream" had been stripped (caller passed stream=None).
    stream = json_params.pop('stream', False)

    json_body = json.dumps(json_params)
    params = {
        'body': json_body,
        'modelId': model_id,
    }

    try:
        if stream:
            result = self._client.invoke_model_with_response_stream(
                **params)
            return StreamingChat(result['body'], self.mode)
        else:
            result = self._client.invoke_model(**params)
            return Chat.from_dict(
                json.loads(result['body'].read().decode()))
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))
def generate(
    self,
    prompt: str,
    # should only be passed for stacked finetune deployment
    model: Optional[str] = None,
    # should only be passed for Bedrock mode; ignored otherwise
    model_id: Optional[str] = None,
    # requires DB with presets
    # preset: str = None,
    num_generations: int = 1,
    max_tokens: int = 400,
    temperature: float = 1.0,
    k: int = 0,
    p: float = 0.75,
    stop_sequences: Optional[List[str]] = None,
    return_likelihoods: Optional[str] = None,
    truncate: Optional[str] = None,
    variant: Optional[str] = None,
    stream: Optional[bool] = True,
) -> Union[Generations, StreamingGenerations]:
    """Generates completions for a prompt via the connected endpoint.

    Args:
        prompt (str): The prompt to generate completions for.
        model (str): (Optional) Model name; should only be passed for stacked finetune deployment.
        model_id (str): (Optional) Bedrock model id; should only be passed in Bedrock mode, ignored otherwise.
        num_generations (int): Number of completions to request. Defaults to 1.
            Only forwarded in SageMaker mode (see TODO below).
        max_tokens (int): Max tokens to generate per completion. Defaults to 400.
        temperature (float): Sampling temperature. Defaults to 1.0.
        k (int): Top-k sampling parameter. Defaults to 0.
        p (float): Nucleus sampling probability. Defaults to 0.75.
        stop_sequences (List[str]): (Optional) Sequences at which to stop generation.
        return_likelihoods (str): (Optional) Forwarded verbatim in the request body.
        truncate (str): (Optional) Forwarded verbatim in the request body.
        variant (str): (Optional) SageMaker production variant to target; SageMaker mode only.
        stream (bool): Return streaming tokens. Defaults to True.

    Returns:
        Generations if stream is falsy, or StreamingGenerations if stream=True.

    Raises:
        CohereError: if no endpoint is connected in SageMaker mode, or the client
            mode is neither SageMaker nor Bedrock.
    """
    if self.mode == Mode.SAGEMAKER and self._endpoint_name is None:
        raise CohereError("No endpoint connected. "
                          "Run connect_to_endpoint() first.")

    json_params = {
        'model': model,
        'prompt': prompt,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'k': k,
        'p': p,
        'stop_sequences': stop_sequences,
        'return_likelihoods': return_likelihoods,
        'truncate': truncate,
        'stream': stream,
    }
    # Drop unset fields so the request body only carries explicit values.
    for key, value in list(json_params.items()):
        if value is None:
            del json_params[key]

    if self.mode == Mode.SAGEMAKER:
        # TODO: Bedrock should support this param too
        json_params['num_generations'] = num_generations
        return self._sagemaker_generations(json_params, variant)
    elif self.mode == Mode.BEDROCK:
        return self._bedrock_generations(json_params, model_id)
    else:
        raise CohereError("Unsupported mode")
def _sagemaker_generations(self, json_params: Dict[str, Any], variant: Optional[str]) -> Union[Generations, StreamingGenerations]:
    """Invoke the connected SageMaker endpoint with a generate request body.

    Args:
        json_params: Request body; None-valued fields were already stripped by generate().
        variant: Optional SageMaker production variant to target.

    Returns:
        StreamingGenerations when streaming was requested, otherwise Generations.

    Raises:
        CohereError: on any invocation failure.
    """
    json_body = json.dumps(json_params)
    params = {
        'EndpointName': self._endpoint_name,
        'ContentType': 'application/json',
        'Body': json_body,
    }
    if variant:
        params['TargetVariant'] = variant

    try:
        # .get(): generate() deletes None-valued keys, so "stream" may be
        # absent (caller explicitly passed stream=None).
        if json_params.get('stream'):
            result = self._client.invoke_endpoint_with_response_stream(
                **params)
            return StreamingGenerations(result['Body'], self.mode)
        else:
            result = self._client.invoke_endpoint(**params)
            return Generations(
                json.loads(result['Body'].read().decode())['generations'])
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))
def _bedrock_generations(self, json_params: Dict[str, Any], model_id: Optional[str]) -> Union[Generations, StreamingGenerations]:
    """Invoke a Bedrock model with a generate request body.

    Args:
        json_params: Request body; None-valued fields were already stripped by generate().
        model_id: Bedrock model identifier (required).

    Returns:
        StreamingGenerations when streaming was requested, otherwise Generations.

    Raises:
        CohereError: if model_id is missing or on any invocation failure.
    """
    if not model_id:
        raise CohereError("must supply model_id arg when calling bedrock")

    # NOTE(review): unlike _bedrock_chat, the "stream" key is left inside the
    # serialized body here — confirm Bedrock tolerates it for generate requests.
    json_body = json.dumps(json_params)
    params = {
        'body': json_body,
        'modelId': model_id,
    }

    try:
        # .get(): generate() deletes None-valued keys, so "stream" may be
        # absent (caller explicitly passed stream=None).
        if json_params.get('stream'):
            result = self._client.invoke_model_with_response_stream(
                **params)
            return StreamingGenerations(result['body'], self.mode)
        else:
            result = self._client.invoke_model(**params)
            return Generations(
                json.loads(result['body'].read().decode())['generations'])
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))
def embed(
    self,
    texts: List[str],
    truncate: Optional[str] = None,
    variant: Optional[str] = None,
    input_type: Optional[str] = None,
    model_id: Optional[str] = None,
    output_dimension: Optional[int] = None,
    embedding_types: Optional[List[str]] = None,
) -> Union[Embeddings, Dict[str, List]]:
    """Embed a batch of texts through the connected SageMaker or Bedrock endpoint.

    Builds the request body from the explicitly-set arguments only, then
    dispatches on the client mode.
    """
    payload = {
        'texts': texts,
        'truncate': truncate,
        "input_type": input_type,
        "output_dimension": output_dimension,
        "embedding_types": embedding_types,
    }
    # Only keep fields the caller actually set.
    payload = {field: value for field, value in payload.items() if value is not None}

    if self.mode == Mode.SAGEMAKER:
        return self._sagemaker_embed(payload, variant)
    if self.mode == Mode.BEDROCK:
        return self._bedrock_embed(payload, model_id)
    raise CohereError("Unsupported mode")
def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
    """Invoke the connected SageMaker endpoint with an embed request body.

    Returns the raw dict when the endpoint responds with embeddings keyed by
    type, otherwise wraps the list of vectors in an Embeddings object.

    Raises:
        CohereError: if no endpoint is connected or the invocation fails.
    """
    if self._endpoint_name is None:
        raise CohereError("No endpoint connected. "
                          "Run connect_to_endpoint() first.")

    request = {
        'EndpointName': self._endpoint_name,
        'ContentType': 'application/json',
        'Body': json.dumps(json_params),
    }
    if variant:
        request['TargetVariant'] = variant

    try:
        raw = self._client.invoke_endpoint(**request)
        payload = json.loads(raw['Body'].read().decode())
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))

    vectors = payload['embeddings']
    # A dict response is passed through untouched; a list is wrapped.
    if isinstance(vectors, dict):
        return vectors
    return Embeddings(vectors)
def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
    """Invoke a Bedrock model with an embed request body.

    Returns the raw dict when the model responds with embeddings keyed by
    type, otherwise wraps the list of vectors in an Embeddings object.

    Raises:
        CohereError: if model_id is missing or the invocation fails.
    """
    if not model_id:
        raise CohereError("must supply model_id arg when calling bedrock")

    request = {
        'body': json.dumps(json_params),
        'modelId': model_id,
    }

    try:
        raw = self._client.invoke_model(**request)
        payload = json.loads(raw['body'].read().decode())
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))

    vectors = payload['embeddings']
    # A dict response is passed through untouched; a list is wrapped.
    if isinstance(vectors, dict):
        return vectors
    return Embeddings(vectors)
def rerank(self,
           query: str,
           documents: Union[List[str], List[Dict[str, Any]]],
           top_n: Optional[int] = None,
           variant: Optional[str] = None,
           max_chunks_per_doc: Optional[int] = None,
           rank_fields: Optional[List[str]] = None) -> Reranking:
    """Returns an ordered list of documents ordered by their relevance to the provided query

    Args:
        query (str): The search query
        documents (list[str], list[dict]): The documents to rerank
        top_n (int): (optional) The number of results to return, defaults to return all results
        variant (str): (optional) The SageMaker production variant to target
        max_chunks_per_doc (int): (optional) The maximum number of chunks derived from a document
        rank_fields (list[str]): (optional) The fields used for reranking. This parameter is only supported for rerank v3 models

    Raises:
        CohereError: if no endpoint is connected, a document has an invalid
            format, or the invocation fails.
    """
    if self._endpoint_name is None:
        raise CohereError("No endpoint connected. "
                          "Run connect_to_endpoint() first.")

    parsed_docs = []
    for doc in documents:
        if isinstance(doc, str):
            parsed_docs.append({'text': doc})
        elif isinstance(doc, dict):
            parsed_docs.append(doc)
        else:
            raise CohereError(
                message='invalid format for documents, must be a list of strings or dicts')

    json_params = {
        "query": query,
        "documents": parsed_docs,
        "top_n": top_n,
        "return_documents": False,
        "max_chunks_per_doc": max_chunks_per_doc,
        "rank_fields": rank_fields
    }
    # Strip unset optional fields instead of serializing JSON nulls,
    # consistent with chat/generate/embed/summarize.
    for key, value in list(json_params.items()):
        if value is None:
            del json_params[key]
    json_body = json.dumps(json_params)

    params = {
        'EndpointName': self._endpoint_name,
        'ContentType': 'application/json',
        'Body': json_body,
    }
    if variant is not None:
        params['TargetVariant'] = variant

    try:
        result = self._client.invoke_endpoint(**params)
        response = json.loads(result['Body'].read().decode())
        reranking = Reranking(response)
        # re-attach the original documents, since return_documents is False
        for rank in reranking.results:
            rank.document = parsed_docs[rank.index]
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))

    return reranking
def classify(self, input: List[str], name: str) -> Classifications:
    """Classify each text in `input` with the fine-tuned model `name`.

    Raises:
        CohereError: if no endpoint is connected or the invocation fails.
    """
    if self._endpoint_name is None:
        raise CohereError("No endpoint connected. "
                          "Run connect_to_endpoint() first.")

    request = {
        "EndpointName": self._endpoint_name,
        "ContentType": "application/json",
        "Body": json.dumps({"texts": input, "model_id": name}),
    }

    try:
        raw = self._client.invoke_endpoint(**request)
        body = json.loads(raw["Body"].read().decode())
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))

    return Classifications([Classification(entry) for entry in body])
def create_finetune(
    self,
    name: str,
    train_data: str,
    s3_models_dir: str,
    arn: Optional[str] = None,
    eval_data: Optional[str] = None,
    instance_type: str = "ml.g4dn.xlarge",
    training_parameters: Optional[Dict[str, Any]] = None,  # training algorithm specific hyper-parameters
    role: Optional[str] = None,
    base_model_id: Optional[str] = None,
) -> Optional[str]:
    """Creates a fine-tuning job and returns an optional finetune job ID.

    Args:
        name (str): The name to give to the fine-tuned model.
        train_data (str): An S3 path pointing to the training data.
        s3_models_dir (str): An S3 path pointing to the directory where the fine-tuned model will be saved.
        arn (str, optional): The product ARN of the fine-tuning package. Required in Sagemaker mode and ignored otherwise
        eval_data (str, optional): An S3 path pointing to the eval data. Defaults to None.
        instance_type (str, optional): The EC2 instance type to use for training. Defaults to "ml.g4dn.xlarge".
        training_parameters (Dict[str, Any], optional): Additional training parameters. Defaults to None (no extras).
        role (str, optional): The IAM role to use for the endpoint.
            In Bedrock this mode is required and is used to access s3 input and output data.
            If not provided in sagemaker, sagemaker.get_execution_role() will be used to get the role.
            This should work when one uses the client inside SageMaker. If this errors
            out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
        base_model_id (str, optional): The ID of the Bedrock base model to finetune with. Required in Bedrock mode and ignored otherwise.

    Returns:
        The Bedrock job ARN in Bedrock mode; None in SageMaker mode.

    Raises:
        ValueError: if name is "model" or an S3 path argument is invalid.
    """
    # Validate with an exception instead of `assert` (stripped under -O);
    # matches export_finetune's ValueError for the same check.
    if name == "model":
        raise ValueError("name cannot be 'model'")

    # Copy into a fresh dict: a `= {}` default here was shared between calls
    # and mutated below, leaking "name" across invocations and into the
    # caller's dict.
    training_parameters = dict(training_parameters or {})

    if self.mode == Mode.BEDROCK:
        return self._bedrock_create_finetune(name=name, train_data=train_data, s3_models_dir=s3_models_dir, base_model=base_model_id, eval_data=eval_data, training_parameters=training_parameters, role=role)

    s3_models_dir = s3_models_dir.rstrip("/") + "/"

    if role is None:
        try:
            role = lazy_sagemaker().get_execution_role()
        except ValueError:
            print("Using default role: 'ServiceRoleSagemaker'.")
            role = "ServiceRoleSagemaker"

    training_parameters.update({"name": name})
    estimator = lazy_sagemaker().algorithm.AlgorithmEstimator(
        algorithm_arn=arn,
        role=role,
        instance_count=1,
        instance_type=instance_type,
        sagemaker_session=self._sess,
        output_path=s3_models_dir,
        hyperparameters=training_parameters,
    )

    inputs = {}
    if not train_data.startswith("s3:"):
        raise ValueError("train_data must point to an S3 location.")
    inputs["training"] = train_data
    if eval_data is not None:
        if not eval_data.startswith("s3:"):
            raise ValueError("eval_data must point to an S3 location.")
        inputs["evaluation"] = eval_data
    estimator.fit(inputs=inputs)
    job_name = estimator.latest_training_job.name

    current_filepath = f"{s3_models_dir}{job_name}/output/model.tar.gz"

    s3_resource = lazy_boto3().resource("s3")

    # Copy new model to root of output_model_dir
    bucket, old_key = lazy_sagemaker().s3.parse_s3_url(current_filepath)
    _, new_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_models_dir}{name}.tar.gz")
    s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key})

    # Delete old dir
    bucket, old_short_key = lazy_sagemaker().s3.parse_s3_url(s3_models_dir + job_name)
    s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete()
def export_finetune(
    self,
    name: str,
    s3_checkpoint_dir: str,
    s3_output_dir: str,
    arn: str,
    instance_type: str = "ml.p4de.24xlarge",
    role: Optional[str] = None,
) -> None:
    """Export the merged weights to the TensorRT-LLM inference engine.

    Args:
        name (str): The name used while writing the exported model to the output directory.
        s3_checkpoint_dir (str): An S3 path pointing to the directory of the model checkpoint (merged weights).
        s3_output_dir (str): An S3 path pointing to the directory where the TensorRT-LLM engine will be saved.
        arn (str): The product ARN of the bring your own finetuning algorithm.
        instance_type (str, optional): The EC2 instance type to use for export. Defaults to "ml.p4de.24xlarge".
        role (str, optional): The IAM role to use for export.
            If not provided, sagemaker.get_execution_role() will be used to get the role.
            This should work when one uses the client inside SageMaker. If this errors out,
            the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.

    Raises:
        ValueError: if name is "model" or s3_checkpoint_dir is not an S3 path.
    """
    self._require_sagemaker()
    if name == "model":
        raise ValueError("name cannot be 'model'")

    s3_output_dir = s3_output_dir.rstrip("/") + "/"

    if role is None:
        try:
            role = lazy_sagemaker().get_execution_role()
        except ValueError:
            print("Using default role: 'ServiceRoleSagemaker'.")
            role = "ServiceRoleSagemaker"

    hyperparams = {"name": name}
    algo = lazy_sagemaker().algorithm.AlgorithmEstimator(
        algorithm_arn=arn,
        role=role,
        instance_count=1,
        instance_type=instance_type,
        sagemaker_session=self._sess,
        output_path=s3_output_dir,
        hyperparameters=hyperparams,
    )

    if not s3_checkpoint_dir.startswith("s3:"):
        raise ValueError("s3_checkpoint_dir must point to an S3 location.")
    algo.fit(inputs={"checkpoint": s3_checkpoint_dir})

    job_name = algo.latest_training_job.name
    engine_path = f"{s3_output_dir}{job_name}/output/model.tar.gz"

    s3 = lazy_boto3().resource("s3")

    # Promote the exported engine to <s3_output_dir>/<name>.tar.gz
    bucket, src_key = lazy_sagemaker().s3.parse_s3_url(engine_path)
    _, dst_key = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{name}.tar.gz")
    s3.Object(bucket, dst_key).copy(CopySource={"Bucket": bucket, "Key": src_key})

    # Remove the per-job output directory now that the engine was promoted
    bucket, job_prefix = lazy_sagemaker().s3.parse_s3_url(f"{s3_output_dir}{job_name}")
    s3.Bucket(bucket).objects.filter(Prefix=job_prefix).delete()
def wait_for_finetune_job(self, job_id: str, timeout: int = 2 * 60 * 60) -> str:
    """Waits for a finetune job to complete and returns a model arn if complete.

    Throws an exception if the timeout occurs or if the job does not complete
    successfully.

    Args:
        job_id (str): The arn of the model customization job
        timeout (int, optional): Timeout in seconds. Defaults to 2 hours.

    Returns:
        str: The ARN of the customized output model.

    Raises:
        CohereError: if the timeout elapses before a terminal status, or the
            job ends in a non-Completed state.
    """
    deadline = time.time() + timeout
    while True:
        customization_job = self._service_client.get_model_customization_job(jobIdentifier=job_id)
        job_status = customization_job["status"]
        if job_status in ["Completed", "Failed", "Stopped"]:
            break
        if time.time() > deadline:
            raise CohereError("could not complete finetune within timeout")
        time.sleep(10)  # poll interval

    if job_status != "Completed":
        # message typo fixed: "successfuly" -> "successfully"
        raise CohereError(f"finetune did not finish successfully, ended with {job_status} status")
    return customization_job["outputModelArn"]
def provision_throughput(
    self,
    model_id: str,
    name: str,
    model_units: int,
    commitment_duration: Optional[str] = None
) -> str:
    """Returns the provisioned model arn

    Args:
        model_id (str): The ID or ARN of the model to provision
        name (str): Name of the provisioned throughput model
        model_units (int): Number of units to provision
        commitment_duration (str, optional): Commitment duration, one of ("OneMonth", "SixMonths"), defaults to no commitment if unspecified

    Raises:
        ValueError: if the client is not in Bedrock mode.
    """
    if self.mode != Mode.BEDROCK:
        raise ValueError("can only provision throughput in bedrock")
    kwargs = {}
    # Omit commitmentDuration entirely for the no-commitment default.
    if commitment_duration:
        kwargs["commitmentDuration"] = commitment_duration

    response = self._service_client.create_provisioned_model_throughput(
        provisionedModelName=name,
        modelId=model_id,
        modelUnits=model_units,
        **kwargs
    )
    return response["provisionedModelArn"]
def _bedrock_create_finetune(
    self,
    name: str,
    train_data: str,
    s3_models_dir: str,
    base_model: str,
    eval_data: Optional[str] = None,
    training_parameters: Optional[Dict[str, Any]] = None,  # training algorithm specific hyper-parameters
    role: Optional[str] = None,
) -> str:
    """Create a Bedrock model customization (fine-tuning) job.

    Args:
        name: Name for the custom model; also used to derive the job name.
        train_data: S3 URI of the training data.
        s3_models_dir: S3 URI where job output is written.
        base_model: Identifier of the Bedrock base model to customize.
        eval_data: Optional S3 URI of validation data.
        training_parameters: Optional hyper-parameters passed through to the job.
        role: IAM role ARN used by Bedrock to access the S3 data (required).

    Returns:
        The ARN of the created customization job.

    Raises:
        ValueError: for an empty name, a missing role, or non-S3 data paths.
    """
    if not name:
        raise ValueError("name must not be empty")
    if not role:
        raise ValueError("must provide a role ARN for bedrock finetuning (https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-iam-role.html)")
    if not train_data.startswith("s3:"):
        raise ValueError("train_data must point to an S3 location.")

    # Only pass validationDataConfig when eval data was supplied; the old code
    # referenced the variable unconditionally and raised NameError whenever
    # eval_data was None.
    extra_kwargs = {}
    if eval_data:
        if not eval_data.startswith("s3:"):
            raise ValueError("eval_data must point to an S3 location.")
        extra_kwargs["validationDataConfig"] = {
            "validators": [{
                "s3Uri": eval_data
            }]
        }

    job_name = f"{name}-job"
    customization_job = self._service_client.create_model_customization_job(
        jobName=job_name,
        customModelName=name,
        roleArn=role,
        baseModelIdentifier=base_model,
        trainingDataConfig={"s3Uri": train_data},
        outputDataConfig={"s3Uri": s3_models_dir},
        hyperParameters=training_parameters or {},
        **extra_kwargs,
    )
    return customization_job["jobArn"]
def summarize(
    self,
    text: str,
    length: Optional[str] = "auto",
    format_: Optional[str] = "auto",
    # Only summarize-xlarge is supported on Sagemaker
    # model: Optional[str] = "summarize-xlarge",
    extractiveness: Optional[str] = "auto",
    temperature: Optional[float] = 0.3,
    additional_command: Optional[str] = "",
    variant: Optional[str] = None
) -> Summary:
    """Summarize `text` through the connected SageMaker endpoint.

    Raises:
        CohereError: if no endpoint is connected or the invocation fails.
    """
    self._require_sagemaker()
    if self._endpoint_name is None:
        raise CohereError("No endpoint connected. "
                          "Run connect_to_endpoint() first.")

    payload = {
        'text': text,
        'length': length,
        'format': format_,
        'extractiveness': extractiveness,
        'temperature': temperature,
        'additional_command': additional_command,
    }
    # Only keep fields the caller actually set.
    payload = {field: value for field, value in payload.items() if value is not None}

    request = {
        'EndpointName': self._endpoint_name,
        'ContentType': 'application/json',
        'Body': json.dumps(payload),
    }
    if variant is not None:
        request['TargetVariant'] = variant

    try:
        raw = self._client.invoke_endpoint(**request)
        summary = Summary(json.loads(raw['Body'].read().decode()))
    except lazy_botocore().EndpointConnectionError as e:
        raise CohereError(str(e))
    except Exception as e:
        # TODO should be client error - distinct type from CohereError?
        # ValidationError, e.g. when variant is bad
        raise CohereError(str(e))

    return summary
def delete_endpoint(self) -> None:
    """Delete the connected endpoint and its endpoint configuration.

    Best-effort: missing resources are reported and skipped.

    Raises:
        CohereError: if no endpoint is connected.
    """
    self._require_sagemaker()
    if self._endpoint_name is None:
        raise CohereError("No endpoint connected.")

    # Catch only AWS client errors: the previous bare `except:` also swallowed
    # KeyboardInterrupt/SystemExit. ClientError matches create_endpoint's style.
    try:
        self._service_client.delete_endpoint(EndpointName=self._endpoint_name)
    except lazy_botocore().ClientError:
        print("Endpoint not found, skipping deletion.")

    try:
        self._service_client.delete_endpoint_config(EndpointConfigName=self._endpoint_name)
    except lazy_botocore().ClientError:
        print("Endpoint config not found, skipping deletion.")
def close(self) -> None:
    """Close the underlying runtime and service boto clients.

    Raises:
        AttributeError: re-raised when the installed client does not implement
            close() (older SageMaker/boto versions), after printing a hint.
    """
    try:
        self._client.close()
        self._service_client.close()
    except AttributeError:
        # Older clients lack .close(); surface the problem after the hint.
        print("SageMaker client could not be closed. This might be because you are using an old version of SageMaker.")
        raise
================================================
FILE: src/cohere/manually_maintained/cohere_aws/embeddings.py
================================================
from .response import CohereObject
from typing import Iterator, List
class Embedding(CohereObject):
    """A single embedding vector.

    Thin wrapper around a list of floats; supports iteration and len() by
    delegating to the underlying list.
    """

    def __init__(self, embedding: List[float]) -> None:
        # the raw embedding vector
        self.embedding = embedding

    def __iter__(self) -> Iterator:
        return iter(self.embedding)

    def __len__(self) -> int:
        return len(self.embedding)
class Embeddings(CohereObject):
    """A batch of Embedding results.

    Supports iteration and len() by delegating to the underlying list.
    """

    def __init__(self, embeddings: List[Embedding]) -> None:
        # one entry per input text
        self.embeddings = embeddings

    def __iter__(self) -> Iterator:
        return iter(self.embeddings)

    def __len__(self) -> int:
        return len(self.embeddings)
================================================
FILE: src/cohere/manually_maintained/cohere_aws/error.py
================================================
class CohereError(Exception):
    """Base exception raised by the cohere_aws client.

    Attributes:
        message: Human-readable description (may be None).
        http_status: Optional HTTP status code associated with the failure.
        headers: Response headers, defaulting to an empty dict.
    """

    def __init__(
        self,
        message=None,
        http_status=None,
        headers=None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.http_status = http_status
        self.headers = headers or {}

    def __str__(self) -> str:
        return self.message or ''

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}'
            f'(message={self.message!r}, http_status={self.http_status!r})'
        )
================================================
FILE: src/cohere/manually_maintained/cohere_aws/generation.py
================================================
import json
from typing import Any, Dict, Generator, Iterator, List, NamedTuple, Optional

from .mode import Mode
from .response import CohereObject
class TokenLikelihood(CohereObject):
    """A generated token paired with its reported likelihood."""

    def __init__(self, token: str, likelihood: Optional[float]) -> None:
        # the generated token text
        self.token = token
        # None when the response omitted 'likelihood' (see Generations.from_dict)
        self.likelihood = likelihood
class Generation(CohereObject):
    """A single generated completion: text plus optional per-token likelihoods."""

    def __init__(self,
                 text: str,
                 token_likelihoods: Optional[List[TokenLikelihood]]) -> None:
        self.text = text
        # None when the response carried no 'token_likelihoods' field
        self.token_likelihoods = token_likelihoods
class Generations(CohereObject):
    """Container for a batch of Generation results.

    Note: iteration is backed by a single shared iterator created at
    construction time, so an instance can only be iterated once.
    """

    def __init__(self,
                 generations: List[Generation]) -> None:
        self.generations = generations
        self.iterator = iter(generations)

    @classmethod
    def from_dict(cls, response: Dict[str, Any]) -> "Generations":
        """Build a Generations instance from a decoded response payload.

        Fixed return annotation: this returns cls(...), not List[Generation].
        """
        generations: List[Generation] = []
        for gen in response['generations']:
            token_likelihoods = None
            if 'token_likelihoods' in gen:
                token_likelihoods = []
                for likelihoods in gen['token_likelihoods']:
                    # 'likelihood' may be absent for some tokens; keep None then
                    token_likelihood = likelihoods.get('likelihood')
                    token_likelihoods.append(TokenLikelihood(
                        likelihoods['token'], token_likelihood))
            generations.append(Generation(gen['text'], token_likelihoods))
        return cls(generations)

    def __iter__(self) -> Iterator[Generation]:
        return self.iterator

    def __next__(self) -> Generation:
        return next(self.iterator)
# One chunk of streamed generation output: `index` identifies the generation
# the chunk belongs to, `text` is the incremental text, and `is_finished`
# marks the terminal chunk of the stream.
StreamingText = NamedTuple(
    "StreamingText",
    [("index", Optional[int]), ("text", str), ("is_finished", bool)],
)
class StreamingGenerations(CohereObject):
    """Iterates a streamed generate response, yielding `StreamingText` chunks.

    The raw event stream's envelope differs between backends: SageMaker
    wraps bytes in ``PayloadPart``/``Bytes``, Bedrock in ``chunk``/``bytes``.
    `mode` selects the correct keys.

    Note: the original module never imported ``CohereError`` even though it
    is raised below — that NameError is fixed by the import added at the top
    of this module.
    """

    def __init__(self, stream, mode) -> None:
        self.stream = stream
        self.id = None
        self.generations = None
        self.finish_reason = None
        # Accumulates raw bytes until they parse as one complete JSON object.
        self.bytes = bytearray()
        if mode == Mode.SAGEMAKER:
            self.payload_key = "PayloadPart"
            self.bytes_key = "Bytes"
        elif mode == Mode.BEDROCK:
            self.payload_key = "chunk"
            self.bytes_key = "bytes"
        else:
            raise CohereError("Unsupported mode")

    def _make_response_item(self, streaming_item) -> Optional[StreamingText]:
        """Convert one parsed stream event into a `StreamingText`, or None.

        The final event carries no text; instead it holds the finish reason
        and the complete response, which are recorded on this object.
        """
        is_finished = streaming_item.get("is_finished")
        if not is_finished:
            index = streaming_item.get("index", 0)
            text = streaming_item.get("text")
            if text is None:
                return None
            return StreamingText(
                text=text, is_finished=is_finished, index=index)
        self.finish_reason = streaming_item.get("finish_reason")
        generation_response = streaming_item.get("response")
        if generation_response is None:
            return None
        self.id = generation_response.get("id")
        self.generations = Generations.from_dict(generation_response)
        return None

    def __iter__(self) -> Generator[StreamingText, None, None]:
        for payload in self.stream:
            self.bytes.extend(payload[self.payload_key][self.bytes_key])
            try:
                item = self._make_response_item(json.loads(self.bytes))
            except json.decoder.JSONDecodeError:
                # Payload contained only a partial JSON object; keep
                # accumulating bytes until a full object parses.
                continue
            self.bytes = bytearray()
            if item is not None:
                yield item
================================================
FILE: src/cohere/manually_maintained/cohere_aws/mode.py
================================================
from enum import Enum
class Mode(Enum):
    """AWS deployment backend targeted by the client."""

    SAGEMAKER = 1
    BEDROCK = 2
================================================
FILE: src/cohere/manually_maintained/cohere_aws/rerank.py
================================================
from typing import Any, Dict, Iterator, List, NamedTuple, Optional
from .response import CohereObject
# Record type for a reranked document. The runtime type name is kept as
# "Document" (what co.rerank historically returned); the payload always
# contains `text` but the source dict may carry arbitrary extra fields.
RerankDocument = NamedTuple("Document", [("text", str)])
RerankDocument.__doc__ = """
Returned by co.rerank,
dict which always contains text but can also contain arbitrary fields
"""
class RerankResult(CohereObject):
    """One result from co.rerank: the (optional) document, its index in the
    original input list, and its relevance score."""

    def __init__(self,
                 document: Optional[Dict[str, Any]] = None,
                 index: Optional[int] = None,
                 relevance_score: Optional[float] = None,
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.document = document
        self.index = index
        self.relevance_score = relevance_score

    def __repr__(self) -> str:
        # BUG FIX: the f-strings previously contained no placeholders, so
        # score/index/text were computed but never shown; restore the
        # interpolation the unused locals clearly intended.
        score = self.relevance_score
        index = self.index
        if self.document is None:
            return f"RerankResult<index: {index}, relevance_score: {score}>"
        elif 'text' in self.document:
            text = self.document['text']
            return f"RerankResult<document['text']: {text}, index: {index}, relevance_score: {score}>"
        else:
            return f"RerankResult<document: {self.document}, index: {index}, relevance_score: {score}>"
class Reranking(CohereObject):
    """Parsed co.rerank response; behaves like a sequence of `RerankResult`."""

    def __init__(self,
                 response: Optional[Dict[str, Any]] = None,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        assert response is not None
        self.results = self._results(response)

    def _results(self, response: Dict[str, Any]) -> List[RerankResult]:
        """Build `RerankResult` objects from the raw `results` payload."""
        parsed: List[RerankResult] = []
        for entry in response['results']:
            if 'document' in entry:
                parsed.append(RerankResult(
                    entry['document'], entry['index'], entry['relevance_score']))
            else:
                parsed.append(RerankResult(
                    index=entry['index'],
                    relevance_score=entry['relevance_score']))
        return parsed

    def __str__(self) -> str:
        return str(self.results)

    def __repr__(self) -> str:
        return repr(self.results)

    def __iter__(self) -> Iterator:
        return iter(self.results)

    def __getitem__(self, index) -> RerankResult:
        return self.results[index]
================================================
FILE: src/cohere/manually_maintained/cohere_aws/response.py
================================================
class CohereObject():
    """Base class giving response objects a uniform, readable debug `repr`."""

    def __repr__(self) -> str:
        # List every attribute except transient iteration state, one per line.
        excluded = ('iterator',)
        body = ''.join(
            f'\t{name}: {value}\n'
            for name, value in self.__dict__.items()
            if name not in excluded
        )
        return f'cohere.{type(self).__name__} {{\n{body}}}'
================================================
FILE: src/cohere/manually_maintained/cohere_aws/summary.py
================================================
from .error import CohereError
from .response import CohereObject
from typing import Any, Dict, Optional
class Summary(CohereObject):
    """Result of a summarize call; wraps the `summary` field of the response.

    Raises:
        CohereError: If the response has no usable summary.
    """

    def __init__(self,
                 response: Optional[Dict[str, Any]] = None) -> None:
        assert response is not None
        # BUG FIX: use .get() so a response missing the "summary" key raises
        # the intended CohereError rather than a bare KeyError.
        if not response.get("summary"):
            raise CohereError("Response lacks a summary")
        self.result = response["summary"]

    def __str__(self) -> str:
        return self.result
================================================
FILE: src/cohere/manually_maintained/lazy_aws_deps.py
================================================
# Shared error text for every missing-dependency failure below.
warning = "AWS dependencies are not installed. Please install boto3, botocore, and sagemaker."


def lazy_sagemaker():
    """Import and return the `sagemaker` package on first use."""
    try:
        import sagemaker as sage  # type: ignore
    except ImportError:
        raise ImportError(warning)
    return sage


def lazy_boto3():
    """Import and return the `boto3` package on first use."""
    try:
        import boto3  # type: ignore
    except ImportError:
        raise ImportError(warning)
    return boto3


def lazy_botocore():
    """Import and return the `botocore` package on first use."""
    try:
        import botocore  # type: ignore
    except ImportError:
        raise ImportError(warning)
    return botocore
================================================
FILE: src/cohere/manually_maintained/lazy_oci_deps.py
================================================
"""Lazy loading for optional OCI SDK dependency."""
from typing import Any
# Error text shown when the optional `oci` package is not installed.
OCI_INSTALLATION_MESSAGE = """
The OCI SDK is required to use OciClient or OciClientV2.
Install it with:
pip install oci
Or with the optional dependency group:
pip install cohere[oci]
"""


def lazy_oci() -> Any:
    """Import and return the OCI SDK module on first use.

    Returns:
        The oci module

    Raises:
        ImportError: If the OCI SDK is not installed
    """
    try:
        import oci  # type: ignore[import-untyped, import-not-found]
    except ImportError:
        raise ImportError(OCI_INSTALLATION_MESSAGE)
    return oci
================================================
FILE: src/cohere/manually_maintained/streaming_embed.py
================================================
"""Utilities for streaming embed responses without loading all embeddings into memory."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, List, Optional, Union
@dataclass
class StreamedEmbedding:
    """A single embedding yielded incrementally from embed_stream()."""

    index: int  # position of this embedding within the full dataset
    embedding: Union[List[float], List[int]]
    embedding_type: str  # canonical type name, e.g. "float", "int8"
    text: Optional[str] = None  # source text for this embedding, when known


def _iter_typed_embeddings(
    embeddings_obj: dict,
    batch_texts: List[str],
    global_offset: int,
) -> Iterator[StreamedEmbedding]:
    """Yield embeddings from a ``{type_key: [embedding, ...]}`` mapping.

    Shared by the V1 ``embeddings_by_type`` and V2 formats, which carry the
    same shape (the original duplicated this loop in both branches).
    Trailing underscores in type keys (e.g. ``float_``) are stripped to get
    the canonical type name. Non-list values are skipped.
    """
    for emb_type, embeddings_list in embeddings_obj.items():
        type_name = emb_type.rstrip("_")
        if isinstance(embeddings_list, list):
            for i, embedding in enumerate(embeddings_list):
                yield StreamedEmbedding(
                    index=global_offset + i,
                    embedding=embedding,
                    embedding_type=type_name,
                    text=batch_texts[i] if i < len(batch_texts) else None,
                )


def extract_embeddings_from_response(
    response_data: dict,
    batch_texts: List[str],
    global_offset: int = 0,
) -> Iterator[StreamedEmbedding]:
    """
    Extract individual embeddings from a Cohere embed response dict.

    Works for both V1 (embeddings_floats / embeddings_by_type) and V2
    response formats.

    Args:
        response_data: Parsed JSON response from embed endpoint
        batch_texts: The texts that were embedded in this batch
        global_offset: Starting index for this batch within the full dataset

    Yields:
        StreamedEmbedding objects
    """
    response_type = response_data.get("response_type", "")
    if response_type == "embeddings_floats":
        embeddings = response_data.get("embeddings", [])
        for i, embedding in enumerate(embeddings):
            yield StreamedEmbedding(
                index=global_offset + i,
                embedding=embedding,
                embedding_type="float",
                text=batch_texts[i] if i < len(batch_texts) else None,
            )
    elif response_type == "embeddings_by_type":
        yield from _iter_typed_embeddings(
            response_data.get("embeddings", {}), batch_texts, global_offset)
    else:
        # V2 format: embeddings is a dict with type keys directly.
        embeddings_obj = response_data.get("embeddings", {})
        if isinstance(embeddings_obj, dict):
            yield from _iter_typed_embeddings(
                embeddings_obj, batch_texts, global_offset)
================================================
FILE: src/cohere/manually_maintained/tokenizers.py
================================================
import asyncio
import logging
import typing
import requests
from tokenizers import Tokenizer # type: ignore
if typing.TYPE_CHECKING:
from cohere.client import AsyncClient, Client
# Namespace prefix for tokenizer entries in the client-side cache.
TOKENIZER_CACHE_KEY = "tokenizers"
logger = logging.getLogger(__name__)


def tokenizer_cache_key(model: str) -> str:
    """Return the cache key under which `model`'s tokenizer is stored."""
    return f"{TOKENIZER_CACHE_KEY}:{model}"
def get_hf_tokenizer(co: "Client", model: str) -> Tokenizer:
    """Returns a HF tokenizer from a given tokenizer config URL.

    The tokenizer is cached on the client after the first download.

    Raises:
        ValueError: If the model exposes no tokenizer URL.
    """
    tokenizer = co._cache_get(tokenizer_cache_key(model))
    if tokenizer is not None:
        return tokenizer
    tokenizer_url = co.models.get(model).tokenizer_url
    if not tokenizer_url:
        raise ValueError(f"No tokenizer URL found for model {model}")
    # Log the size of the tokenizer config before downloading it.
    try:
        size = _get_tokenizer_config_size(tokenizer_url)
        logger.info(f"Downloading tokenizer for model {model}. Size is {size} MBs.")
    except Exception as e:
        # Skip the size logging, this is not critical.
        # FIX: Logger.warn is deprecated; use Logger.warning.
        logger.warning(f"Failed to get the size of the tokenizer config: {e}")
    response = requests.get(tokenizer_url)
    tokenizer = Tokenizer.from_str(response.text)
    co._cache_set(tokenizer_cache_key(model), tokenizer)
    return tokenizer
def local_tokenize(co: "Client", model: str, text: str) -> typing.List[int]:
    """Encodes a given text using a local tokenizer."""
    encoding = get_hf_tokenizer(co, model).encode(text, add_special_tokens=False)
    return encoding.ids
def local_detokenize(co: "Client", model: str, tokens: typing.Sequence[int]) -> str:
    """Decodes a given list of tokens using a local tokenizer."""
    return get_hf_tokenizer(co, model).decode(tokens)
async def async_get_hf_tokenizer(co: "AsyncClient", model: str) -> Tokenizer:
    """Returns a HF tokenizer from a given tokenizer config URL.

    The tokenizer is cached on the client after the first download. The
    blocking `requests.get` is pushed onto the running loop's default
    executor so the event loop is not blocked.

    Raises:
        ValueError: If the model exposes no tokenizer URL.
    """
    tokenizer = co._cache_get(tokenizer_cache_key(model))
    if tokenizer is not None:
        return tokenizer
    tokenizer_url = (await co.models.get(model)).tokenizer_url
    if not tokenizer_url:
        raise ValueError(f"No tokenizer URL found for model {model}")
    # Log the size of the tokenizer config before downloading it.
    try:
        size = _get_tokenizer_config_size(tokenizer_url)
        logger.info(f"Downloading tokenizer for model {model}. Size is {size} MBs.")
    except Exception as e:
        # Skip the size logging, this is not critical.
        # FIX: Logger.warn is deprecated; use Logger.warning.
        logger.warning(f"Failed to get the size of the tokenizer config: {e}")
    # FIX: get_event_loop() is deprecated inside a coroutine since 3.10;
    # get_running_loop() is the correct call here and behaves identically.
    response = await asyncio.get_running_loop().run_in_executor(None, requests.get, tokenizer_url)
    tokenizer = Tokenizer.from_str(response.text)
    co._cache_set(tokenizer_cache_key(model), tokenizer)
    return tokenizer
async def async_local_tokenize(co: "AsyncClient", model: str, text: str) -> typing.List[int]:
    """Encodes a given text using a local tokenizer."""
    tok = await async_get_hf_tokenizer(co, model)
    return tok.encode(text, add_special_tokens=False).ids
async def async_local_detokenize(co: "AsyncClient", model: str, tokens: typing.Sequence[int]) -> str:
    """Decodes a given list of tokens using a local tokenizer."""
    tok = await async_get_hf_tokenizer(co, model)
    return tok.decode(tokens)
def _get_tokenizer_config_size(tokenizer_url: str) -> float:
    """Return the size (in MBs, 2 decimals) of the tokenizer config at `tokenizer_url`.

    Issues a HEAD request. Content-Length is not always present in the
    headers (if transfer-encoding: chunked), so the GCS-specific header is
    checked first.

    Raises:
        ValueError: If no size header is present in the response.
    """
    head_response = requests.head(tokenizer_url)
    size = None
    for header in ["x-goog-stored-content-length", "Content-Length"]:
        size = head_response.headers.get(header)
        if size:
            break
    # FIX: previously a missing header fell through to int(None) and raised
    # an opaque TypeError (masked by typing.cast); fail explicitly instead.
    # Callers wrap this in a broad `except Exception`, so behavior is
    # compatible.
    if size is None:
        raise ValueError(f"Could not determine tokenizer config size for {tokenizer_url}")
    return round(int(size) / 1024 / 1024, 2)
================================================
FILE: src/cohere/models/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
================================================
FILE: src/cohere/models/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from ..types.compatible_endpoint import CompatibleEndpoint
from ..types.get_model_response import GetModelResponse
from ..types.list_models_response import ListModelsResponse
from .raw_client import AsyncRawModelsClient, RawModelsClient
class ModelsClient:
    """High-level synchronous client for the models API; delegates to RawModelsClient."""

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._raw_client = RawModelsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> RawModelsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        RawModelsClient
        """
        return self._raw_client

    def get(self, model: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetModelResponse:
        """
        Returns the details of a model, provided its name.

        Parameters
        ----------
        model : str
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetModelResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.models.get(
            model="command-a-03-2025",
        )
        """
        return self._raw_client.get(model, request_options=request_options).data

    def list(
        self,
        *,
        page_size: typing.Optional[float] = None,
        page_token: typing.Optional[str] = None,
        endpoint: typing.Optional[CompatibleEndpoint] = None,
        default_only: typing.Optional[bool] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListModelsResponse:
        """
        Returns a list of models available for use.

        Parameters
        ----------
        page_size : typing.Optional[float]
            Maximum number of models to include in a page.
            Defaults to `20`, min value of `1`, max value of `1000`.
        page_token : typing.Optional[str]
            Page token provided in the `next_page_token` field of a previous response.
        endpoint : typing.Optional[CompatibleEndpoint]
            When provided, filters the list of models to only those compatible with the specified endpoint.
        default_only : typing.Optional[bool]
            When provided, filters the list to only the endpoint's default model; requires `endpoint`.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListModelsResponse
            OK

        Examples
        --------
        from cohere import Client

        client = Client(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )
        client.models.list(
            page_size=1.1,
            page_token="page_token",
            endpoint="chat",
            default_only=True,
        )
        """
        raw_response = self._raw_client.list(
            page_size=page_size,
            page_token=page_token,
            endpoint=endpoint,
            default_only=default_only,
            request_options=request_options,
        )
        return raw_response.data
class AsyncModelsClient:
    """High-level asynchronous client for the models API; delegates to AsyncRawModelsClient."""

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawModelsClient(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawModelsClient:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawModelsClient
        """
        return self._raw_client

    async def get(self, model: str, *, request_options: typing.Optional[RequestOptions] = None) -> GetModelResponse:
        """
        Returns the details of a model, provided its name.

        Parameters
        ----------
        model : str
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        GetModelResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.models.get(
                model="command-a-03-2025",
            )

        asyncio.run(main())
        """
        raw_response = await self._raw_client.get(model, request_options=request_options)
        return raw_response.data

    async def list(
        self,
        *,
        page_size: typing.Optional[float] = None,
        page_token: typing.Optional[str] = None,
        endpoint: typing.Optional[CompatibleEndpoint] = None,
        default_only: typing.Optional[bool] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> ListModelsResponse:
        """
        Returns a list of models available for use.

        Parameters
        ----------
        page_size : typing.Optional[float]
            Maximum number of models to include in a page.
            Defaults to `20`, min value of `1`, max value of `1000`.
        page_token : typing.Optional[str]
            Page token provided in the `next_page_token` field of a previous response.
        endpoint : typing.Optional[CompatibleEndpoint]
            When provided, filters the list of models to only those compatible with the specified endpoint.
        default_only : typing.Optional[bool]
            When provided, filters the list to only the endpoint's default model; requires `endpoint`.
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ListModelsResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )

        async def main() -> None:
            await client.models.list(
                page_size=1.1,
                page_token="page_token",
                endpoint="chat",
                default_only=True,
            )

        asyncio.run(main())
        """
        raw_response = await self._raw_client.list(
            page_size=page_size,
            page_token=page_token,
            endpoint=endpoint,
            default_only=default_only,
            request_options=request_options,
        )
        return raw_response.data
================================================
FILE: src/cohere/models/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from json.decoder import JSONDecodeError
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.http_response import AsyncHttpResponse, HttpResponse
from ..core.jsonable_encoder import jsonable_encoder
from ..core.parse_error import ParsingError
from ..core.request_options import RequestOptions
from ..core.unchecked_base_model import construct_type
from ..errors.bad_request_error import BadRequestError
from ..errors.client_closed_request_error import ClientClosedRequestError
from ..errors.forbidden_error import ForbiddenError
from ..errors.gateway_timeout_error import GatewayTimeoutError
from ..errors.internal_server_error import InternalServerError
from ..errors.invalid_token_error import InvalidTokenError
from ..errors.not_found_error import NotFoundError
from ..errors.not_implemented_error import NotImplementedError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.too_many_requests_error import TooManyRequestsError
from ..errors.unauthorized_error import UnauthorizedError
from ..errors.unprocessable_entity_error import UnprocessableEntityError
from ..types.compatible_endpoint import CompatibleEndpoint
from ..types.get_model_response import GetModelResponse
from ..types.list_models_response import ListModelsResponse
from pydantic import ValidationError
class RawModelsClient:
    """Low-level synchronous models client returning HttpResponse-wrapped data.

    The generated code repeated an identical twelve-branch status-code
    dispatch in every method; it is factored into `_raise_known_error`
    without changing which exceptions are raised or their payloads.
    """

    # Maps an HTTP status code to the typed error class raised for it.
    _STATUS_CODE_TO_ERROR: typing.Dict[int, typing.Any] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    def _raise_known_error(self, _response) -> None:
        """Raise the typed error mapped to `_response.status_code`, if any.

        Must be called inside the caller's try block so that JSON decoding
        and validation failures are converted uniformly (ApiError /
        ParsingError), exactly as in the original inline dispatch.
        """
        error_cls = self._STATUS_CODE_TO_ERROR.get(_response.status_code)
        if error_cls is not None:
            raise error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )

    def get(
        self, model: str, *, request_options: typing.Optional[RequestOptions] = None
    ) -> HttpResponse[GetModelResponse]:
        """
        Returns the details of a model, provided its name.

        Parameters
        ----------
        model : str

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[GetModelResponse]
            OK
        """
        _response = self._client_wrapper.httpx_client.request(
            f"v1/models/{jsonable_encoder(model)}",
            method="GET",
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    GetModelResponse,
                    construct_type(
                        type_=GetModelResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            self._raise_known_error(_response)
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

    def list(
        self,
        *,
        page_size: typing.Optional[float] = None,
        page_token: typing.Optional[str] = None,
        endpoint: typing.Optional[CompatibleEndpoint] = None,
        default_only: typing.Optional[bool] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[ListModelsResponse]:
        """
        Returns a list of models available for use.

        Parameters
        ----------
        page_size : typing.Optional[float]
            Maximum number of models to include in a page.
            Defaults to `20`, min value of `1`, max value of `1000`.

        page_token : typing.Optional[str]
            Page token provided in the `next_page_token` field of a previous response.

        endpoint : typing.Optional[CompatibleEndpoint]
            When provided, filters the list of models to only those that are compatible with the specified endpoint.

        default_only : typing.Optional[bool]
            When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ListModelsResponse]
            OK
        """
        _response = self._client_wrapper.httpx_client.request(
            "v1/models",
            method="GET",
            params={
                "page_size": page_size,
                "page_token": page_token,
                "endpoint": endpoint,
                "default_only": default_only,
            },
            request_options=request_options,
        )
        try:
            if 200 <= _response.status_code < 300:
                _data = typing.cast(
                    ListModelsResponse,
                    construct_type(
                        type_=ListModelsResponse,  # type: ignore
                        object_=_response.json(),
                    ),
                )
                return HttpResponse(response=_response, data=_data)
            self._raise_known_error(_response)
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
        except ValidationError as e:
            raise ParsingError(
                status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
            )
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
class AsyncRawModelsClient:
    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        # Async HTTP transport shared with the generated client machinery.
        self._client_wrapper = client_wrapper
async def get(
self, model: str, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[GetModelResponse]:
"""
Returns the details of a model, provided its name.
Parameters
----------
model : str
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[GetModelResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
f"v1/models/{jsonable_encoder(model)}",
method="GET",
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
GetModelResponse,
construct_type(
type_=GetModelResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def list(
    self,
    *,
    page_size: typing.Optional[float] = None,
    page_token: typing.Optional[str] = None,
    endpoint: typing.Optional[CompatibleEndpoint] = None,
    default_only: typing.Optional[bool] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[ListModelsResponse]:
    """
    Returns a list of models available for use.

    Parameters
    ----------
    page_size : typing.Optional[float]
        Maximum number of models to include in a page
        Defaults to `20`, min value of `1`, max value of `1000`.

    page_token : typing.Optional[str]
        Page token provided in the `next_page_token` field of a previous response.

    endpoint : typing.Optional[CompatibleEndpoint]
        When provided, filters the list of models to only those that are compatible with the specified endpoint.

    default_only : typing.Optional[bool]
        When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[ListModelsResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/models",
        method="GET",
        params={
            "page_size": page_size,
            "page_token": page_token,
            "endpoint": endpoint,
            "default_only": default_only,
        },
        request_options=request_options,
    )
    # Status-code → SDK exception class; every non-2xx body is parsed the same way.
    _error_by_status: typing.Dict[int, typing.Any] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                ListModelsResponse,
                construct_type(
                    type_=ListModelsResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        _error_cls = _error_by_status.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # The error body was not valid JSON; surface the raw text instead.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unmapped status code: fall through to a generic ApiError.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/oci_client.py
================================================
"""Oracle Cloud Infrastructure (OCI) client for Cohere API."""
import configparser
import email.utils
import json
import os
import typing
import uuid
import httpx
import requests
from .client import Client, ClientEnvironment
from .client_v2 import ClientV2
from .aws_client import Streamer
from .manually_maintained.lazy_oci_deps import lazy_oci
from httpx import URL, ByteStream
class OciClient(Client):
    """
    Cohere V1 API client for Oracle Cloud Infrastructure (OCI) Generative AI service.

    Use this client for V1 API models (Command R family) and embeddings.
    For V2 API models (Command A family), use OciClientV2 instead.

    Supported APIs on OCI:
    - embed(): Full support for all embedding models
    - chat(): Full support with Command-R models
    - chat_stream(): Streaming chat support

    Supports all authentication methods:
    - Config file (default): Uses ~/.oci/config
    - Session-based: Uses OCI CLI session tokens
    - Direct credentials: Pass OCI credentials directly
    - Instance principal: For OCI compute instances
    - Resource principal: For OCI functions

    Example:
        ```python
        import cohere

        client = cohere.OciClient(
            oci_region="us-chicago-1",
            oci_compartment_id="ocid1.compartment.oc1...",
        )
        response = client.chat(
            model="command-r-08-2024",
            message="Hello!",
        )
        print(response.text)
        ```
    """

    def __init__(
        self,
        *,
        oci_config_path: typing.Optional[str] = None,
        oci_profile: typing.Optional[str] = None,
        oci_user_id: typing.Optional[str] = None,
        oci_fingerprint: typing.Optional[str] = None,
        oci_tenancy_id: typing.Optional[str] = None,
        oci_private_key_path: typing.Optional[str] = None,
        oci_private_key_content: typing.Optional[str] = None,
        auth_type: typing.Literal["api_key", "instance_principal", "resource_principal"] = "api_key",
        oci_region: typing.Optional[str] = None,
        oci_compartment_id: str,
        timeout: typing.Optional[float] = None,
    ):
        # Resolve OCI credentials for the selected authentication scheme.
        resolved_config = _load_oci_config(
            auth_type=auth_type,
            config_path=oci_config_path,
            profile=oci_profile,
            user_id=oci_user_id,
            fingerprint=oci_fingerprint,
            tenancy_id=oci_tenancy_id,
            private_key_path=oci_private_key_path,
            private_key_content=oci_private_key_content,
        )
        # An explicit region wins; otherwise fall back to the config file's region.
        region = oci_region if oci_region is not None else resolved_config.get("region")
        if region is None:
            raise ValueError("oci_region must be provided either directly or in OCI config file")
        # base_url/api_key are placeholders: the event hooks rewrite every request
        # to target the OCI endpoint and sign it with the OCI credentials.
        hooks = get_event_hooks(
            oci_config=resolved_config,
            oci_region=region,
            oci_compartment_id=oci_compartment_id,
            is_v2_client=False,
        )
        Client.__init__(
            self,
            base_url="https://api.cohere.com",
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=httpx.Client(event_hooks=hooks, timeout=timeout),
        )
class OciClientV2(ClientV2):
    """
    Cohere V2 API client for Oracle Cloud Infrastructure (OCI) Generative AI service.

    Supported APIs on OCI:
    - embed(): Full support for all embedding models (returns embeddings as dict)
    - chat(): Full support with Command-A models (command-a-03-2025)
    - chat_stream(): Streaming chat with proper V2 event format

    Note: rerank() requires fine-tuned models deployed to dedicated endpoints.
    OCI on-demand inference does not support the rerank API.

    Supports all authentication methods:
    - Config file (default): Uses ~/.oci/config
    - Session-based: Uses OCI CLI session tokens
    - Direct credentials: Pass OCI credentials directly
    - Instance principal: For OCI compute instances
    - Resource principal: For OCI functions

    Example using config file:
        ```python
        import cohere

        client = cohere.OciClientV2(
            oci_region="us-chicago-1",
            oci_compartment_id="ocid1.compartment.oc1...",
        )
        response = client.embed(
            model="embed-english-v3.0",
            texts=["Hello world"],
            input_type="search_document",
        )
        print(response.embeddings.float_)
        response = client.chat(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "Hello!"}],
        )
        print(response.message)
        ```

    Example using direct credentials:
        ```python
        client = cohere.OciClientV2(
            oci_user_id="ocid1.user.oc1...",
            oci_fingerprint="xx:xx:xx:...",
            oci_tenancy_id="ocid1.tenancy.oc1...",
            oci_private_key_path="~/.oci/key.pem",
            oci_region="us-chicago-1",
            oci_compartment_id="ocid1.compartment.oc1...",
        )
        ```

    Example using instance principal:
        ```python
        client = cohere.OciClientV2(
            auth_type="instance_principal",
            oci_region="us-chicago-1",
            oci_compartment_id="ocid1.compartment.oc1...",
        )
        ```
    """

    def __init__(
        self,
        *,
        # Authentication - Config file (default)
        oci_config_path: typing.Optional[str] = None,
        oci_profile: typing.Optional[str] = None,
        # Authentication - Direct credentials
        oci_user_id: typing.Optional[str] = None,
        oci_fingerprint: typing.Optional[str] = None,
        oci_tenancy_id: typing.Optional[str] = None,
        oci_private_key_path: typing.Optional[str] = None,
        oci_private_key_content: typing.Optional[str] = None,
        # Authentication - Instance principal
        auth_type: typing.Literal["api_key", "instance_principal", "resource_principal"] = "api_key",
        # Required for OCI Generative AI
        oci_region: typing.Optional[str] = None,
        oci_compartment_id: str,
        # Standard parameters
        timeout: typing.Optional[float] = None,
    ):
        # Resolve OCI credentials for the selected authentication scheme.
        resolved_config = _load_oci_config(
            auth_type=auth_type,
            config_path=oci_config_path,
            profile=oci_profile,
            user_id=oci_user_id,
            fingerprint=oci_fingerprint,
            tenancy_id=oci_tenancy_id,
            private_key_path=oci_private_key_path,
            private_key_content=oci_private_key_content,
        )
        # An explicit region wins; otherwise fall back to the config file's region.
        region = oci_region if oci_region is not None else resolved_config.get("region")
        if region is None:
            raise ValueError("oci_region must be provided either directly or in OCI config file")
        # base_url/api_key are placeholders: the event hooks rewrite every request
        # to target the OCI endpoint and sign it with the OCI credentials.
        hooks = get_event_hooks(
            oci_config=resolved_config,
            oci_region=region,
            oci_compartment_id=oci_compartment_id,
            is_v2_client=True,
        )
        ClientV2.__init__(
            self,
            base_url="https://api.cohere.com",  # Unused, OCI URL set in hooks
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=httpx.Client(event_hooks=hooks, timeout=timeout),
        )
# Type of an httpx event hook: a callable invoked with a request or response object.
EventHook = typing.Callable[..., typing.Any]
def _load_oci_config(
    auth_type: str,
    config_path: typing.Optional[str],
    profile: typing.Optional[str],
    **kwargs: typing.Any,
) -> typing.Dict[str, typing.Any]:
    """
    Load OCI configuration based on authentication type.

    Args:
        auth_type: Authentication method (api_key, instance_principal, resource_principal)
        config_path: Path to OCI config file (for api_key auth)
        profile: Profile name in config file (for api_key auth)
        **kwargs: Direct credentials (user_id, fingerprint, etc.)

    Returns:
        Dictionary containing OCI configuration
    """
    oci = lazy_oci()
    # Principal-based auth: the signer object itself carries the credentials.
    if auth_type == "instance_principal":
        return {
            "signer": oci.auth.signers.InstancePrincipalsSecurityTokenSigner(),
            "auth_type": "instance_principal",
        }
    if auth_type == "resource_principal":
        return {
            "signer": oci.auth.signers.get_resource_principals_signer(),
            "auth_type": "resource_principal",
        }
    if kwargs.get("user_id"):
        # Direct credentials: the companion fields must all be present.
        missing = [f for f in ("fingerprint", "tenancy_id") if not kwargs.get(f)]
        if missing:
            raise ValueError(
                f"When providing oci_user_id, you must also provide: {', '.join('oci_' + f for f in missing)}"
            )
        if not kwargs.get("private_key_path") and not kwargs.get("private_key_content"):
            raise ValueError(
                "When providing oci_user_id, you must also provide either "
                "oci_private_key_path or oci_private_key_content"
            )
        direct_config: typing.Dict[str, typing.Any] = {
            "user": kwargs["user_id"],
            "fingerprint": kwargs["fingerprint"],
            "tenancy": kwargs["tenancy_id"],
        }
        if kwargs.get("private_key_path"):
            direct_config["key_file"] = kwargs["private_key_path"]
        if kwargs.get("private_key_content"):
            direct_config["key_content"] = kwargs["private_key_content"]
        return direct_config
    # Default: read the requested profile from an OCI config file on disk.
    file_config = oci.config.from_file(
        file_location=config_path or "~/.oci/config", profile_name=profile or "DEFAULT"
    )
    _remove_inherited_session_auth(file_config, config_path=config_path, profile=profile)
    return file_config
def _remove_inherited_session_auth(
oci_config: typing.Dict[str, typing.Any],
*,
config_path: typing.Optional[str],
profile: typing.Optional[str],
) -> None:
"""Drop session auth fields inherited from the OCI config DEFAULT section."""
profile_name = profile or "DEFAULT"
if profile_name == "DEFAULT" or "security_token_file" not in oci_config:
return
config_file = os.path.expanduser(config_path or "~/.oci/config")
parser = configparser.ConfigParser(interpolation=None)
if not parser.read(config_file):
return
if not parser.has_section(profile_name):
oci_config.pop("security_token_file", None)
return
explicit_security_token = False
current_section: typing.Optional[str] = None
with open(config_file, encoding="utf-8") as handle:
for raw_line in handle:
line = raw_line.strip()
if not line or line.startswith(("#", ";")):
continue
if line.startswith("[") and line.endswith("]"):
current_section = line[1:-1].strip()
continue
if current_section == profile_name and line.split("=", 1)[0].strip() == "security_token_file":
explicit_security_token = True
break
if not explicit_security_token:
oci_config.pop("security_token_file", None)
def _usage_from_oci(usage_data: typing.Optional[typing.Dict[str, typing.Any]]) -> typing.Dict[str, typing.Any]:
usage_data = usage_data or {}
input_tokens = usage_data.get("inputTokens", 0)
output_tokens = usage_data.get("completionTokens", usage_data.get("outputTokens", 0))
return {
"tokens": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
},
"billed_units": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}
}
def get_event_hooks(
    oci_config: typing.Dict[str, typing.Any],
    oci_region: str,
    oci_compartment_id: str,
    is_v2_client: bool = False,
) -> typing.Dict[str, typing.List[EventHook]]:
    """
    Create httpx event hooks for OCI request/response transformation.

    Args:
        oci_config: OCI configuration dictionary
        oci_region: OCI region (e.g., "us-chicago-1")
        oci_compartment_id: OCI compartment OCID
        is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False)

    Returns:
        Dictionary of event hooks for httpx
    """
    # The request hook rewrites + signs outgoing requests for OCI; the
    # response hook converts OCI payloads back into Cohere's wire format.
    request_hook = map_request_to_oci(
        oci_config=oci_config,
        oci_region=oci_region,
        oci_compartment_id=oci_compartment_id,
        is_v2_client=is_v2_client,
    )
    response_hook = map_response_from_oci()
    return {"request": [request_hook], "response": [response_hook]}
def map_request_to_oci(
    oci_config: typing.Dict[str, typing.Any],
    oci_region: str,
    oci_compartment_id: str,
    is_v2_client: bool = False,
) -> EventHook:
    """
    Create event hook that transforms Cohere requests to OCI format and signs them.

    The returned hook mutates the outgoing ``httpx.Request`` in place: it rewrites
    the URL to the OCI Generative AI endpoint, replaces the body with the OCI
    request shape, and copies over OCI-signed headers.

    Args:
        oci_config: OCI configuration dictionary
        oci_region: OCI region
        oci_compartment_id: OCI compartment OCID
        is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False)

    Returns:
        Event hook function for httpx

    Raises:
        ValueError: If the config supports no usable signing method.
    """
    oci = lazy_oci()
    # Create OCI signer based on config type
    # Priority order: instance/resource principal > session-based auth > API key auth
    if "signer" in oci_config:
        signer = oci_config["signer"]  # Instance/resource principal
    elif "security_token_file" in oci_config:
        # Session-based authentication with security token.
        # The token file is re-read on every request so that OCI CLI token refreshes
        # (e.g. `oci session refresh`) are picked up without restarting the client.
        key_file = oci_config.get("key_file")
        if not key_file:
            raise ValueError(
                "OCI config profile is missing 'key_file'. "
                "Session-based auth requires a key_file entry in your OCI config profile."
            )
        token_file_path = os.path.expanduser(oci_config["security_token_file"])
        # The private key is loaded once; only the token is re-read per request.
        private_key = oci.signer.load_private_key_from_file(os.path.expanduser(key_file))

        class _RefreshingSecurityTokenSigner:
            """Wraps SecurityTokenSigner and re-reads the token file before each signing call."""

            def __init__(self) -> None:
                self._token_file = token_file_path
                self._private_key = private_key
                self._refresh()

            def _refresh(self) -> None:
                # Rebuild the inner signer from the current on-disk token contents.
                with open(self._token_file, "r") as _f:
                    _token = _f.read().strip()
                self._signer = oci.auth.signers.SecurityTokenSigner(
                    token=_token,
                    private_key=self._private_key,
                )

            # Delegate all attribute access to the inner signer, refreshing first.
            def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
                self._refresh()
                return self._signer(*args, **kwargs)

            def __getattr__(self, name: str) -> typing.Any:
                # Underscore names are looked up on the wrapper itself; rejecting
                # them here also avoids recursion before __init__ sets attributes.
                if name.startswith("_"):
                    raise AttributeError(name)
                self._refresh()
                return getattr(self._signer, name)

        signer = _RefreshingSecurityTokenSigner()
    elif "user" in oci_config:
        # Standard API-key auth built from direct or file-based credentials.
        signer = oci.signer.Signer(
            tenancy=oci_config["tenancy"],
            user=oci_config["user"],
            fingerprint=oci_config["fingerprint"],
            private_key_file_location=oci_config.get("key_file"),
            private_key_content=oci_config.get("key_content"),
        )
    else:
        # Config doesn't have user or security token - unsupported
        raise ValueError(
            "OCI config is missing 'user' field and no security_token_file found. "
            "Please use a profile with standard API key authentication, "
            "session-based authentication, or provide direct credentials via oci_user_id parameter."
        )

    def _event_hook(request: httpx.Request) -> None:
        # Extract Cohere API details
        path_parts = request.url.path.split("/")
        endpoint = path_parts[-1]
        body = json.loads(request.read())
        # Build OCI URL
        url = get_oci_url(
            region=oci_region,
            endpoint=endpoint,
        )
        # Transform request body to OCI format
        oci_body = transform_request_to_oci(
            endpoint=endpoint,
            cohere_body=body,
            compartment_id=oci_compartment_id,
            is_v2=is_v2_client,
        )
        # Prepare request for signing
        oci_body_bytes = json.dumps(oci_body).encode("utf-8")
        # Build headers for signing
        headers = {
            "content-type": "application/json",
            "date": email.utils.formatdate(usegmt=True),
        }
        # Create a requests.PreparedRequest for OCI signing
        oci_request = requests.Request(
            method=request.method,
            url=url,
            headers=headers,
            data=oci_body_bytes,
        )
        prepped_request = oci_request.prepare()
        # Sign the request using OCI signer (modifies headers in place)
        signer.do_request_sign(prepped_request)
        # Update httpx request with signed headers
        request.url = URL(url)
        request.headers = httpx.Headers(prepped_request.headers)
        # Replace both the stream and the cached content so httpx sends the OCI body.
        request.stream = ByteStream(oci_body_bytes)
        request._content = oci_body_bytes
        # Stash routing info for the response hook (see map_response_from_oci).
        request.extensions["endpoint"] = endpoint
        request.extensions["is_stream"] = body.get("stream", False)
        request.extensions["is_v2"] = is_v2_client

    return _event_hook
def map_response_from_oci() -> EventHook:
    """
    Create event hook that transforms OCI responses to Cohere format.

    The returned hook rewrites successful ``httpx.Response`` objects in place,
    replacing their stream/content with the Cohere-shaped payload.

    Returns:
        Event hook function for httpx
    """

    def _hook(response: httpx.Response) -> None:
        # Extensions were stashed on the request by the hook in map_request_to_oci.
        endpoint = response.request.extensions["endpoint"]
        is_stream = response.request.extensions.get("is_stream", False)
        is_v2 = response.request.extensions.get("is_v2", False)
        output: typing.Iterator[bytes]
        # Only transform successful responses (200-299)
        # Let error responses pass through unchanged so SDK error handling works
        if not (200 <= response.status_code < 300):
            return
        # For streaming responses, wrap the stream with a transformer
        if is_stream:
            original_stream = typing.cast(typing.Iterator[bytes], response.stream)
            transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint, is_v2)
            response.stream = Streamer(transformed_stream)
            # Reset consumption flags
            # NOTE(review): relies on httpx internals (_content, is_stream_consumed,
            # is_closed) so the SDK can re-read the swapped-in stream.
            if hasattr(response, "_content"):
                del response._content
            response.is_stream_consumed = False
            response.is_closed = False
            return
        # Handle non-streaming responses
        oci_response = json.loads(response.read())
        cohere_response = transform_oci_response_to_cohere(endpoint, oci_response, is_v2)
        # Re-expose the transformed JSON as a single-chunk byte stream.
        output = iter([json.dumps(cohere_response).encode("utf-8")])
        response.stream = Streamer(output)
        # Reset response for re-reading
        if hasattr(response, "_content"):
            del response._content
        response.is_stream_consumed = False
        response.is_closed = False

    return _hook
def get_oci_url(
    region: str,
    endpoint: str,
) -> str:
    """
    Map Cohere endpoints to OCI Generative AI endpoints.

    Args:
        region: OCI region (e.g., "us-chicago-1")
        endpoint: Cohere endpoint name

    Returns:
        Full OCI Generative AI endpoint URL

    Raises:
        ValueError: If the endpoint has no OCI equivalent.
    """
    # Cohere endpoint name -> OCI Generative AI action name.
    action_map = {
        "embed": "embedText",
        "chat": "chat",
    }
    try:
        action = action_map[endpoint]
    except KeyError:
        raise ValueError(
            f"Endpoint '{endpoint}' is not supported by OCI Generative AI. "
            f"Supported endpoints: {list(action_map.keys())}"
        ) from None
    return (
        f"https://inference.generativeai.{region}.oci.oraclecloud.com"
        f"/20231130/actions/{action}"
    )
def normalize_model_for_oci(model: str) -> str:
    """
    Normalize model name for OCI.

    OCI accepts model names in the format "cohere.model-name" or full OCIDs.
    This function ensures proper formatting for all regions.

    Args:
        model: Model name (e.g., "command-r-08-2024") or full OCID

    Returns:
        Normalized model identifier (e.g., "cohere.command-r-08-2024" or OCID)

    Raises:
        ValueError: If ``model`` is empty.

    Examples:
        >>> normalize_model_for_oci("command-a-03-2025")
        'cohere.command-a-03-2025'
        >>> normalize_model_for_oci("cohere.embed-english-v3.0")
        'cohere.embed-english-v3.0'
        >>> normalize_model_for_oci("ocid1.generativeaimodel.oc1...")
        'ocid1.generativeaimodel.oc1...'
    """
    if not model:
        raise ValueError("OCI requests require a non-empty model name")
    # If it's already an OCID, return as-is (works across all regions)
    if model.startswith("ocid1."):
        return model
    # Add "cohere." prefix if not present
    if not model.startswith("cohere."):
        return f"cohere.{model}"
    return model
def transform_request_to_oci(
    endpoint: str,
    cohere_body: typing.Dict[str, typing.Any],
    compartment_id: str,
    is_v2: bool = False,
) -> typing.Dict[str, typing.Any]:
    """
    Transform Cohere request body to OCI format.

    Args:
        endpoint: Cohere endpoint name
        cohere_body: Original Cohere request body
        compartment_id: OCI compartment OCID
        is_v2: Whether this request comes from OciClientV2 (True) or OciClient (False)

    Returns:
        Transformed request body in OCI format

    Raises:
        ValueError: If the endpoint is unsupported, the embed body lacks inputs,
            or a chat body's shape does not match the client version (V1 vs V2).
    """
    model = normalize_model_for_oci(cohere_body.get("model", ""))
    if endpoint == "embed":
        # OCI uses a single "inputs" field; Cohere's "texts" maps onto it.
        if "texts" in cohere_body:
            inputs = cohere_body["texts"]
        elif "inputs" in cohere_body:
            inputs = cohere_body["inputs"]
        elif "images" in cohere_body:
            raise ValueError("OCI embed does not support the top-level 'images' parameter; use 'inputs' instead")
        else:
            raise ValueError("OCI embed requires either 'texts' or 'inputs'")
        oci_body = {
            "inputs": inputs,
            "servingMode": {
                "servingType": "ON_DEMAND",
                "modelId": model,
            },
            "compartmentId": compartment_id,
        }
        # Add optional fields only if provided
        if "input_type" in cohere_body:
            oci_body["inputType"] = cohere_body["input_type"].upper()
        if "truncate" in cohere_body:
            oci_body["truncate"] = cohere_body["truncate"].upper()
        if "embedding_types" in cohere_body:
            # OCI expects lowercase embedding types (float, int8, binary, etc.)
            oci_body["embeddingTypes"] = [et.lower() for et in cohere_body["embedding_types"]]
        if "max_tokens" in cohere_body:
            oci_body["maxTokens"] = cohere_body["max_tokens"]
        if "output_dimension" in cohere_body:
            oci_body["outputDimension"] = cohere_body["output_dimension"]
        if "priority" in cohere_body:
            oci_body["priority"] = cohere_body["priority"]
        return oci_body
    elif endpoint == "chat":
        # Validate that the request body matches the client type
        has_messages = "messages" in cohere_body
        has_message = "message" in cohere_body
        if is_v2 and not has_messages:
            raise ValueError(
                "OciClientV2 requires the V2 API format ('messages' array). "
                "Got a V1-style request with 'message' string. "
                "Use OciClient for V1 models like Command R, "
                "or switch to the V2 messages format."
            )
        if not is_v2 and has_messages and not has_message:
            raise ValueError(
                "OciClient uses the V1 API format (single 'message' string). "
                "Got a V2-style request with 'messages' array. "
                "Use OciClientV2 for V2 models like Command A."
            )
        chat_request: typing.Dict[str, typing.Any] = {
            "apiFormat": "COHEREV2" if is_v2 else "COHERE",
        }
        if is_v2:
            # V2: Transform Cohere V2 messages to OCI V2 format
            # Cohere sends: [{"role": "user", "content": "text"}]
            # OCI expects: [{"role": "USER", "content": [{"type": "TEXT", "text": "..."}]}]
            oci_messages = []
            for msg in cohere_body["messages"]:
                oci_msg: typing.Dict[str, typing.Any] = {
                    "role": msg["role"].upper(),
                }
                # Transform content
                if isinstance(msg.get("content"), str):
                    # String shorthand becomes a single TEXT content block.
                    oci_msg["content"] = [{"type": "TEXT", "text": msg["content"]}]
                elif isinstance(msg.get("content"), list):
                    transformed_content = []
                    for item in msg["content"]:
                        if isinstance(item, dict) and "type" in item:
                            transformed_item = item.copy()
                            transformed_item["type"] = item["type"].upper()
                            # OCI expects camelCase: image_url → imageUrl
                            if "image_url" in transformed_item:
                                transformed_item["imageUrl"] = transformed_item.pop("image_url")
                            transformed_content.append(transformed_item)
                        else:
                            transformed_content.append(item)
                    oci_msg["content"] = transformed_content
                else:
                    # Missing/None content (e.g. tool-call-only messages) → empty list.
                    oci_msg["content"] = msg.get("content") or []
                if "tool_calls" in msg:
                    oci_tool_calls = []
                    for tc in msg["tool_calls"]:
                        oci_tc = {**tc}
                        if "type" in oci_tc:
                            oci_tc["type"] = oci_tc["type"].upper()
                        oci_tool_calls.append(oci_tc)
                    oci_msg["toolCalls"] = oci_tool_calls
                if "tool_call_id" in msg:
                    oci_msg["toolCallId"] = msg["tool_call_id"]
                if "tool_plan" in msg:
                    oci_msg["toolPlan"] = msg["tool_plan"]
                oci_messages.append(oci_msg)
            chat_request["messages"] = oci_messages
            # V2 optional parameters (snake_case Cohere keys → camelCase OCI keys)
            if "max_tokens" in cohere_body:
                chat_request["maxTokens"] = cohere_body["max_tokens"]
            if "temperature" in cohere_body:
                chat_request["temperature"] = cohere_body["temperature"]
            if "k" in cohere_body:
                chat_request["topK"] = cohere_body["k"]
            if "p" in cohere_body:
                chat_request["topP"] = cohere_body["p"]
            if "seed" in cohere_body:
                chat_request["seed"] = cohere_body["seed"]
            if "frequency_penalty" in cohere_body:
                chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"]
            if "presence_penalty" in cohere_body:
                chat_request["presencePenalty"] = cohere_body["presence_penalty"]
            if "stop_sequences" in cohere_body:
                chat_request["stopSequences"] = cohere_body["stop_sequences"]
            if "tools" in cohere_body:
                oci_tools = []
                for tool in cohere_body["tools"]:
                    oci_tool = {**tool}
                    if "type" in oci_tool:
                        oci_tool["type"] = oci_tool["type"].upper()
                    oci_tools.append(oci_tool)
                chat_request["tools"] = oci_tools
            if "strict_tools" in cohere_body:
                chat_request["strictTools"] = cohere_body["strict_tools"]
            if "documents" in cohere_body:
                chat_request["documents"] = cohere_body["documents"]
            if "citation_options" in cohere_body:
                chat_request["citationOptions"] = cohere_body["citation_options"]
            if "response_format" in cohere_body:
                chat_request["responseFormat"] = cohere_body["response_format"]
            if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
                chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
            if "logprobs" in cohere_body:
                chat_request["logprobs"] = cohere_body["logprobs"]
            if "tool_choice" in cohere_body:
                chat_request["toolChoice"] = cohere_body["tool_choice"]
            if "priority" in cohere_body:
                chat_request["priority"] = cohere_body["priority"]
            # Thinking parameter for Command A Reasoning models
            if "thinking" in cohere_body and cohere_body["thinking"] is not None:
                thinking = cohere_body["thinking"]
                oci_thinking: typing.Dict[str, typing.Any] = {}
                if "type" in thinking:
                    oci_thinking["type"] = thinking["type"].upper()
                if "token_budget" in thinking and thinking["token_budget"] is not None:
                    oci_thinking["tokenBudget"] = thinking["token_budget"]
                if oci_thinking:
                    chat_request["thinking"] = oci_thinking
        else:
            # V1: single message string
            chat_request["message"] = cohere_body["message"]
            if "temperature" in cohere_body:
                chat_request["temperature"] = cohere_body["temperature"]
            if "max_tokens" in cohere_body:
                chat_request["maxTokens"] = cohere_body["max_tokens"]
            if "k" in cohere_body:
                chat_request["topK"] = cohere_body["k"]
            if "p" in cohere_body:
                chat_request["topP"] = cohere_body["p"]
            if "seed" in cohere_body:
                chat_request["seed"] = cohere_body["seed"]
            if "stop_sequences" in cohere_body:
                chat_request["stopSequences"] = cohere_body["stop_sequences"]
            if "frequency_penalty" in cohere_body:
                chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"]
            if "presence_penalty" in cohere_body:
                chat_request["presencePenalty"] = cohere_body["presence_penalty"]
            if "preamble" in cohere_body:
                chat_request["preambleOverride"] = cohere_body["preamble"]
            if "chat_history" in cohere_body:
                chat_request["chatHistory"] = cohere_body["chat_history"]
            if "documents" in cohere_body:
                chat_request["documents"] = cohere_body["documents"]
            if "tools" in cohere_body:
                oci_tools = []
                for tool in cohere_body["tools"]:
                    oci_tool = {**tool}
                    if "type" in oci_tool:
                        oci_tool["type"] = oci_tool["type"].upper()
                    oci_tools.append(oci_tool)
                chat_request["tools"] = oci_tools
            if "tool_results" in cohere_body:
                chat_request["toolResults"] = cohere_body["tool_results"]
            if "response_format" in cohere_body:
                chat_request["responseFormat"] = cohere_body["response_format"]
            if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
                chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
            if "priority" in cohere_body:
                chat_request["priority"] = cohere_body["priority"]
        # Handle streaming for both versions
        if cohere_body.get("stream"):
            chat_request["isStream"] = True
        # Top level OCI request structure
        oci_body = {
            "servingMode": {
                "servingType": "ON_DEMAND",
                "modelId": model,
            },
            "compartmentId": compartment_id,
            "chatRequest": chat_request,
        }
        return oci_body
    raise ValueError(
        f"Endpoint '{endpoint}' is not supported by OCI Generative AI on-demand inference. "
        "Supported endpoints: ['embed', 'chat']"
    )
def transform_oci_response_to_cohere(
    endpoint: str, oci_response: typing.Dict[str, typing.Any], is_v2: bool = False,
) -> typing.Dict[str, typing.Any]:
    """
    Transform OCI response to Cohere format.

    Unknown endpoints are passed through unchanged.

    Args:
        endpoint: Cohere endpoint name
        oci_response: OCI response body
        is_v2: Whether this is a V2 API response

    Returns:
        Transformed response in Cohere format
    """
    if endpoint == "embed":
        # OCI returns "embeddings" by default, or "embeddingsByType" when embeddingTypes is specified
        embeddings_data = oci_response.get("embeddingsByType") or oci_response.get("embeddings", {})
        if isinstance(embeddings_data, dict):
            # Lowercase keys so e.g. "FLOAT" matches Cohere's "float" bucket.
            normalized_embeddings = {str(key).lower(): value for key, value in embeddings_data.items()}
        else:
            # A bare list means float embeddings.
            normalized_embeddings = {"float": embeddings_data}
        if is_v2:
            # V2 keeps the per-type dict; V1 exposes only the float list.
            embeddings = normalized_embeddings
        else:
            embeddings = normalized_embeddings.get("float", [])
        meta = {
            "api_version": {"version": "1"},
        }
        usage = _usage_from_oci(oci_response.get("usage"))
        if "tokens" in usage:
            meta["tokens"] = usage["tokens"]
        if "billed_units" in usage:
            meta["billed_units"] = usage["billed_units"]
        response_type = "embeddings_by_type" if is_v2 else "embeddings_floats"
        return {
            "response_type": response_type,
            "id": oci_response.get("id", str(uuid.uuid4())),
            "embeddings": embeddings,
            "texts": [],
            "meta": meta,
        }
    elif endpoint == "chat":
        chat_response = oci_response.get("chatResponse", {})
        if is_v2:
            usage = _usage_from_oci(chat_response.get("usage"))
            message = chat_response.get("message", {})
            # Lowercase role/content types back to Cohere's casing; all edits are
            # made on copies so the original OCI payload is never mutated.
            if "role" in message:
                message = {**message, "role": message["role"].lower()}
            if "content" in message and isinstance(message["content"], list):
                transformed_content = []
                for item in message["content"]:
                    if isinstance(item, dict):
                        transformed_item = item.copy()
                        if "type" in transformed_item:
                            transformed_item["type"] = transformed_item["type"].lower()
                        transformed_content.append(transformed_item)
                    else:
                        transformed_content.append(item)
                message = {**message, "content": transformed_content}
            if "toolCalls" in message:
                tool_calls = []
                for tc in message["toolCalls"]:
                    lowered_tc = {**tc}
                    if "type" in lowered_tc:
                        lowered_tc["type"] = lowered_tc["type"].lower()
                    tool_calls.append(lowered_tc)
                # Rename camelCase "toolCalls" to snake_case "tool_calls".
                message = {k: v for k, v in message.items() if k != "toolCalls"}
                message["tool_calls"] = tool_calls
            if "toolPlan" in message:
                tool_plan = message["toolPlan"]
                message = {k: v for k, v in message.items() if k != "toolPlan"}
                message["tool_plan"] = tool_plan
            return {
                "id": chat_response.get("id", str(uuid.uuid4())),
                "message": message,
                "finish_reason": chat_response.get("finishReason", "COMPLETE"),
                "usage": usage,
            }
        # V1 response
        meta = {
            "api_version": {"version": "1"},
        }
        usage = _usage_from_oci(chat_response.get("usage"))
        if "tokens" in usage:
            meta["tokens"] = usage["tokens"]
        if "billed_units" in usage:
            meta["billed_units"] = usage["billed_units"]
        return {
            "text": chat_response.get("text", ""),
            "generation_id": str(uuid.uuid4()),
            "chat_history": chat_response.get("chatHistory", []),
            "finish_reason": chat_response.get("finishReason", "COMPLETE"),
            "citations": chat_response.get("citations", []),
            "documents": chat_response.get("documents", []),
            "search_queries": chat_response.get("searchQueries", []),
            "meta": meta,
        }
    return oci_response
def transform_oci_stream_wrapper(
    stream: typing.Iterator[bytes], endpoint: str, is_v2: bool = False,
) -> typing.Iterator[bytes]:
    """
    Wrap OCI stream and transform events to Cohere format.

    Re-frames the upstream OCI SSE stream into Cohere's wire format: V2 output is
    SSE ("data: {...}\n\n") with message-start / content-start / content-delta /
    content-end / message-end envelope events; V1 output is newline-delimited JSON
    with stream-start / text-generation / stream-end events.

    Args:
        stream: Original OCI stream iterator
        endpoint: Cohere endpoint name
        is_v2: Whether this is a V2 API stream
    Yields:
        Bytes of transformed streaming events
    """
    # Stream-level state shared by the nested closures below (via nonlocal).
    generation_id = str(uuid.uuid4())  # OCI events carry no id; synthesize one for the whole stream
    emitted_start = False  # opening envelope events sent?
    emitted_content_end = False
    current_content_type: typing.Optional[str] = None  # "text" or "thinking" (V2 only)
    current_content_index = 0  # increments on each text<->thinking transition
    final_finish_reason = "COMPLETE"
    final_usage: typing.Optional[typing.Dict[str, typing.Any]] = None
    full_v1_text = ""  # accumulated text for the V1 stream-end "response" payload
    final_v1_finish_reason = "COMPLETE"
    buffer = b""  # carries partial lines across chunk boundaries
    def _emit_v2_event(event: typing.Dict[str, typing.Any]) -> bytes:
        # V2 clients parse SSE framing.
        return b"data: " + json.dumps(event).encode("utf-8") + b"\n\n"
    def _emit_v1_event(event: typing.Dict[str, typing.Any]) -> bytes:
        # V1 clients parse newline-delimited JSON.
        return json.dumps(event).encode("utf-8") + b"\n"
    def _current_content_type(oci_event: typing.Dict[str, typing.Any]) -> typing.Optional[str]:
        # Inspect the first content item of the event's message to classify it.
        message = oci_event.get("message")
        if isinstance(message, dict):
            content_list = message.get("content")
            if content_list and isinstance(content_list, list) and len(content_list) > 0:
                oci_type = content_list[0].get("type", "TEXT").upper()
                return "thinking" if oci_type == "THINKING" else "text"
        return None  # finish-only or non-content event — don't trigger a type transition
    def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]:
        nonlocal emitted_start, emitted_content_end, current_content_type, current_content_index
        nonlocal final_finish_reason, final_usage
        event_content_type = _current_content_type(oci_event)
        open_type = event_content_type or "text"  # default opening block type when indeterminate
        if not emitted_start:
            # First event: open the message and the first content block.
            yield _emit_v2_event(
                {
                    "type": "message-start",
                    "id": generation_id,
                    "delta": {"message": {"role": "assistant"}},
                }
            )
            yield _emit_v2_event(
                {
                    "type": "content-start",
                    "index": current_content_index,
                    "delta": {"message": {"content": {"type": open_type}}},
                }
            )
            emitted_start = True
            current_content_type = open_type
        elif event_content_type is not None and current_content_type != event_content_type:
            # Content type flipped (text <-> thinking): close current block, open a new one.
            yield _emit_v2_event({"type": "content-end", "index": current_content_index})
            current_content_index += 1
            yield _emit_v2_event(
                {
                    "type": "content-start",
                    "index": current_content_index,
                    "delta": {"message": {"content": {"type": event_content_type}}},
                }
            )
            current_content_type = event_content_type
            emitted_content_end = False
        for cohere_event in typing.cast(
            typing.List[typing.Dict[str, typing.Any]], transform_stream_event(endpoint, oci_event, is_v2=True)
        ):
            # Re-home the event onto the current content block index.
            if "index" in cohere_event:
                cohere_event = {**cohere_event, "index": current_content_index}
            if cohere_event["type"] == "content-end":
                # Capture the terminal metadata for the eventual message-end event.
                emitted_content_end = True
                final_finish_reason = oci_event.get("finishReason", final_finish_reason)
                final_usage = _usage_from_oci(oci_event.get("usage"))
            yield _emit_v2_event(cohere_event)
    def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]:
        nonlocal emitted_start, full_v1_text, final_v1_finish_reason
        if not emitted_start:
            yield _emit_v1_event({
                "event_type": "stream-start",
                "generation_id": generation_id,
                "is_finished": False,
            })
            emitted_start = True
        event = transform_stream_event(endpoint, oci_event, is_v2=False)
        if isinstance(event, dict):
            if event.get("event_type") == "text-generation" and event.get("text"):
                # Accumulate so stream-end can echo the full response text.
                full_v1_text += typing.cast(str, event["text"])
            if "finishReason" in oci_event:
                final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason)
            yield _emit_v1_event(event)
    stream_finished = False
    def _emit_closing_events() -> typing.Iterator[bytes]:
        """Emit the final closing events for the stream."""
        if is_v2:
            if emitted_start:
                if not emitted_content_end:
                    # Ensure the open content block is closed before message-end.
                    yield _emit_v2_event({"type": "content-end", "index": current_content_index})
                message_end_event: typing.Dict[str, typing.Any] = {
                    "type": "message-end",
                    "id": generation_id,
                    "delta": {"finish_reason": final_finish_reason},
                }
                if final_usage:
                    message_end_event["delta"]["usage"] = final_usage
                yield _emit_v2_event(message_end_event)
        else:
            yield _emit_v1_event(
                {
                    "event_type": "stream-end",
                    "finish_reason": final_v1_finish_reason,
                    "response": {
                        "text": full_v1_text,
                        "generation_id": generation_id,
                        "finish_reason": final_v1_finish_reason,
                    },
                }
            )
    def _process_line(line: str) -> typing.Iterator[bytes]:
        nonlocal stream_finished
        # Only SSE data lines are meaningful; comments/blank lines are dropped.
        if not line.startswith("data: "):
            return
        data_str = line[6:]
        if data_str.strip() == "[DONE]":
            for event_bytes in _emit_closing_events():
                yield event_bytes
            stream_finished = True
            return
        try:
            oci_event = json.loads(data_str)
        except json.JSONDecodeError:
            # Malformed payloads are skipped rather than aborting the stream.
            return
        try:
            if is_v2:
                for event_bytes in _transform_v2_event(oci_event):
                    yield event_bytes
            else:
                for event_bytes in _transform_v1_event(oci_event):
                    yield event_bytes
        except Exception as exc:
            raise RuntimeError(f"OCI stream event transformation failed for endpoint '{endpoint}': {exc}") from exc
        # OCI may not send [DONE] — treat finishReason as stream termination
        if "finishReason" in oci_event:
            for event_bytes in _emit_closing_events():
                yield event_bytes
            stream_finished = True
    for chunk in stream:
        buffer += chunk
        while b"\n" in buffer:
            line_bytes, buffer = buffer.split(b"\n", 1)
            line = line_bytes.decode("utf-8").strip()
            for event_bytes in _process_line(line):
                yield event_bytes
            if stream_finished:
                return
    # Flush any trailing data that arrived without a final newline.
    if buffer.strip() and not stream_finished:
        line = buffer.decode("utf-8").strip()
        for event_bytes in _process_line(line):
            yield event_bytes
def transform_stream_event(
    endpoint: str, oci_event: typing.Dict[str, typing.Any], is_v2: bool = False,
) -> typing.Union[typing.Dict[str, typing.Any], typing.List[typing.Dict[str, typing.Any]]]:
    """
    Transform individual OCI stream event to Cohere format.

    Args:
        endpoint: Cohere endpoint name
        oci_event: OCI stream event
        is_v2: Whether this is a V2 API stream
    Returns:
        V2: List of transformed events. V1: Single transformed event dict.
    """
    # Only chat events are translated; anything else gets an empty result.
    if endpoint != "chat":
        return [] if is_v2 else {}
    if not is_v2:
        # V1 stream event: flat text-generation record.
        return {
            "event_type": "text-generation",
            "text": oci_event.get("text", ""),
            "is_finished": oci_event.get("isFinished", False),
        }
    raw_message = oci_event.get("message")
    if "message" in oci_event and not isinstance(raw_message, dict):
        raise TypeError("OCI V2 stream event message must be an object")
    # Classify the first content item of the message (defaults: empty text).
    kind = "text"
    value = ""
    if isinstance(raw_message, dict) and "content" in raw_message:
        parts = raw_message["content"]
        if parts and isinstance(parts, list) and len(parts) > 0:
            head = parts[0]
            if head.get("type", "TEXT").upper() == "THINKING":
                kind = "thinking"
                value = head.get("thinking", "")
            else:
                kind = "text"
                value = head.get("text", "")
    emitted: typing.List[typing.Dict[str, typing.Any]] = []
    if value:
        payload: typing.Dict[str, typing.Any] = (
            {"thinking": value} if kind == "thinking" else {"text": value}
        )
        emitted.append(
            {
                "type": "content-delta",
                "index": 0,
                "delta": {
                    "message": {
                        "content": payload,
                    }
                },
            }
        )
    # A finishReason marks the end of the current content block.
    if "finishReason" in oci_event:
        emitted.append(
            {
                "type": "content-end",
                "index": 0,
            }
        )
    return emitted
================================================
FILE: src/cohere/overrides.py
================================================
import functools
import typing
import uuid
from pprint import pprint

from . import EmbedByTypeResponseEmbeddings
from .core.pydantic_utilities import _get_model_fields, Model, IS_PYDANTIC_V2
def get_fields(obj) -> typing.List[str]:
    """Return the declared pydantic field names of *obj* as plain strings."""
    return [str(field_name) for field_name in _get_model_fields(obj).keys()]
def get_aliases_or_field(obj) -> typing.List[str]:
    """For each field, return its best public name: the field's own alias, a
    metadata-declared alias (pydantic v2), or else the raw field name."""
    names: typing.List[str] = []
    for field_name, field_info in _get_model_fields(obj).items():
        alias = field_info.alias
        if not alias:
            # Fall back to the first metadata entry's alias, mirroring v2's layout.
            alias = field_info and field_info.metadata and field_info.metadata[0] and field_info.metadata[0].alias  # type: ignore
        names.append(alias or field_name)
    return names
def get_aliases_and_fields(obj):
    """Union of raw field names and their aliases, deduplicated (order unspecified)."""
    combined = set(get_fields(obj))
    combined.update(get_aliases_or_field(obj))
    return list(combined)
def allow_access_to_aliases(self: typing.Type["Model"], name):
    """__getattr__ replacement: resolve *name* against either the raw field
    names or their declared aliases, so e.g. ``embeddings.float`` works even
    though the field is stored as ``float_``."""
    for attr, info in _get_model_fields(self).items():
        resolved = info.alias
        if not resolved:
            # pydantic v2 keeps the alias inside the field's metadata list.
            resolved = info and info.metadata and info.metadata[0] and info.metadata[0].alias  # type: ignore
        if resolved == name or attr == name:
            return getattr(self, attr)
    raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'")
def make_tool_call_v2_id_optional(cls):
    """
    Override ToolCallV2 to make the 'id' field optional with a default UUID.

    This ensures backward compatibility with code that doesn't provide an id.
    We wrap the __init__ method to inject a default id before Pydantic validation runs.

    The patch is idempotent: calling this twice (e.g. if run_overrides() is ever
    executed more than once) must not stack a second wrapper, which would
    otherwise add an extra layer of indirection on every construction.

    Args:
        cls: The class whose ``__init__`` should gain a default ``id``.
    Returns:
        The same class, patched in place.
    """
    # Already patched — do nothing so wrappers never stack.
    if getattr(cls.__init__, "_cohere_id_patch", False):
        return cls

    # Store the original __init__ method
    original_init = cls.__init__

    @functools.wraps(original_init)  # preserve name/doc/signature metadata for introspection
    def patched_init(self, /, **data):
        """Patched __init__ that injects default id if not provided."""
        # Inject default UUID if 'id' is not in the data
        if 'id' not in data:
            data['id'] = str(uuid.uuid4())
        # Call the original __init__ with modified data
        original_init(self, **data)

    # Mark the wrapper AFTER functools.wraps (wraps copies the original __dict__,
    # which would otherwise overwrite the flag).
    patched_init._cohere_id_patch = True  # type: ignore[attr-defined]

    # Replace the __init__ method
    cls.__init__ = patched_init
    return cls
def run_overrides():
    """
    These are overrides to allow us to make changes to generated code without touching the generated files themselves.
    Should be used judiciously!
    """
    # Let embeddings be read through their wire aliases, e.g. embeddings.float
    # instead of embeddings.float_.
    EmbedByTypeResponseEmbeddings.__getattr__ = allow_access_to_aliases  # type: ignore[assignment]

    # Imported here rather than at module top to avoid circular dependency issues.
    from . import ToolCallV2

    # Give ToolCallV2 a default UUID 'id' so callers may omit it.
    make_tool_call_v2_id_optional(ToolCallV2)
# Run overrides immediately at module import time to ensure they're applied
# before any code tries to use the modified classes
# (importing this module therefore has the side effect of patching them).
run_overrides()
================================================
FILE: src/cohere/py.typed
================================================
================================================
FILE: src/cohere/raw_base_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import contextlib
import json
import typing
from json.decoder import JSONDecodeError
from .core.api_error import ApiError
from .core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from .core.http_response import AsyncHttpResponse, HttpResponse
from .core.parse_error import ParsingError
from .core.request_options import RequestOptions
from .core.serialization import convert_and_respect_annotation_metadata
from .core.unchecked_base_model import construct_type
from .errors.bad_request_error import BadRequestError
from .errors.client_closed_request_error import ClientClosedRequestError
from .errors.forbidden_error import ForbiddenError
from .errors.gateway_timeout_error import GatewayTimeoutError
from .errors.internal_server_error import InternalServerError
from .errors.invalid_token_error import InvalidTokenError
from .errors.not_found_error import NotFoundError
from .errors.not_implemented_error import NotImplementedError
from .errors.service_unavailable_error import ServiceUnavailableError
from .errors.too_many_requests_error import TooManyRequestsError
from .errors.unauthorized_error import UnauthorizedError
from .errors.unprocessable_entity_error import UnprocessableEntityError
from .types.chat_connector import ChatConnector
from .types.chat_document import ChatDocument
from .types.chat_request_citation_quality import ChatRequestCitationQuality
from .types.chat_request_prompt_truncation import ChatRequestPromptTruncation
from .types.chat_request_safety_mode import ChatRequestSafetyMode
from .types.chat_stream_request_citation_quality import ChatStreamRequestCitationQuality
from .types.chat_stream_request_prompt_truncation import ChatStreamRequestPromptTruncation
from .types.chat_stream_request_safety_mode import ChatStreamRequestSafetyMode
from .types.check_api_key_response import CheckApiKeyResponse
from .types.classify_example import ClassifyExample
from .types.classify_request_truncate import ClassifyRequestTruncate
from .types.classify_response import ClassifyResponse
from .types.detokenize_response import DetokenizeResponse
from .types.embed_input_type import EmbedInputType
from .types.embed_request_truncate import EmbedRequestTruncate
from .types.embed_response import EmbedResponse
from .types.embedding_type import EmbeddingType
from .types.generate_request_return_likelihoods import GenerateRequestReturnLikelihoods
from .types.generate_request_truncate import GenerateRequestTruncate
from .types.generate_stream_request_return_likelihoods import GenerateStreamRequestReturnLikelihoods
from .types.generate_stream_request_truncate import GenerateStreamRequestTruncate
from .types.generate_streamed_response import GenerateStreamedResponse
from .types.generation import Generation
from .types.message import Message
from .types.non_streamed_chat_response import NonStreamedChatResponse
from .types.rerank_request_documents_item import RerankRequestDocumentsItem
from .types.rerank_response import RerankResponse
from .types.response_format import ResponseFormat
from .types.streamed_chat_response import StreamedChatResponse
from .types.summarize_request_extractiveness import SummarizeRequestExtractiveness
from .types.summarize_request_format import SummarizeRequestFormat
from .types.summarize_request_length import SummarizeRequestLength
from .types.summarize_response import SummarizeResponse
from .types.tokenize_response import TokenizeResponse
from .types.tool import Tool
from .types.tool_result import ToolResult
from pydantic import ValidationError
# this is used as the default value for optional parameters
# (the Ellipsis sentinel lets the serializer distinguish "not provided" from an explicit None)
OMIT = typing.cast(typing.Any, ...)
class RawBaseCohere:
    def __init__(self, *, client_wrapper: SyncClientWrapper):
        # Shared HTTP client + auth/header configuration used by every endpoint method.
        self._client_wrapper = client_wrapper
    @contextlib.contextmanager
    def chat_stream(
        self,
        *,
        message: str,
        accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
        model: typing.Optional[str] = OMIT,
        preamble: typing.Optional[str] = OMIT,
        chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
        conversation_id: typing.Optional[str] = OMIT,
        prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT,
        connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
        search_queries_only: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
        citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        max_input_tokens: typing.Optional[int] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        raw_prompting: typing.Optional[bool] = OMIT,
        tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
        tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
        force_single_step: typing.Optional[bool] = OMIT,
        response_format: typing.Optional[ResponseFormat] = OMIT,
        safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.Iterator[HttpResponse[typing.Iterator[StreamedChatResponse]]]:
        """
        Generates a streamed text response to a user message.

        To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

        Parameters
        ----------
        message : str
            Text input for the model to respond to.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        accepts : typing.Optional[typing.Literal["text/event-stream"]]
            Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.

        model : typing.Optional[str]
            The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
            Compatible Deployments: Cohere Platform, Private Deployments

        preamble : typing.Optional[str]
            When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
            The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        chat_history : typing.Optional[typing.Sequence[Message]]
            A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
            Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
            The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        conversation_id : typing.Optional[str]
            An alternative to `chat_history`.
            Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
            Compatible Deployments: Cohere Platform

        prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation]
            Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
            Dictates how the prompt will be constructed.
            With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
            With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
            With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
            Compatible Deployments:
             - AUTO: Cohere Platform Only
             - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments

        connectors : typing.Optional[typing.Sequence[ChatConnector]]
            Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
            When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
            Compatible Deployments: Cohere Platform

        search_queries_only : typing.Optional[bool]
            Defaults to `false`.
            When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        documents : typing.Optional[typing.Sequence[ChatDocument]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
            Example:
            ```
            [
              { "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
              { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
            ]
            ```
            Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
            Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
            An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
            An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
            See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        citation_quality : typing.Optional[ChatStreamRequestCitationQuality]
            Defaults to `"enabled"`.
            Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        temperature : typing.Optional[float]
            Defaults to `0.3`.
            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
            Randomness can be further maximized by increasing the  value of the `p` parameter.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        max_input_tokens : typing.Optional[int]
            The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
            Input will be truncated according to the `prompt_truncation` parameter.
            Compatible Deployments: Cohere Platform

        k : typing.Optional[int]
            Ensures only the top `k` most likely tokens are considered for generation at each step.
            Defaults to `0`, min value of `0`, max value of `500`.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens
            deterministically, such that repeated requests with the same
            seed and parameters should return the same result. However,
            determinism cannot be totally guaranteed.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        raw_prompting : typing.Optional[bool]
            When enabled, the user's prompt will be sent to the model without
            any pre-processing.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tools : typing.Optional[typing.Sequence[Tool]]
            A list of available tools (functions) that the model may suggest invoking before producing a text response.
            When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        tool_results : typing.Optional[typing.Sequence[ToolResult]]
            A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
            Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
            **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
            ```
            tool_results = [
              {
                "call": {
                  "name": <tool name>,
                  "parameters": {
                    <param name>: <param value>
                  }
                },
                "outputs": [{
                  <key>: <value>
                }]
              },
              ...
            ]
            ```
            **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        force_single_step : typing.Optional[bool]
            Forces the chat to be single step. Defaults to `false`.

        response_format : typing.Optional[ResponseFormat]

        safety_mode : typing.Optional[ChatStreamRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `NONE` is specified, the safety instruction will be omitted.
            Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
            **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
            Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.Iterator[HttpResponse[typing.Iterator[StreamedChatResponse]]]
        """
        # Open the HTTP stream inside a context manager so the connection is
        # released when the caller exits the `with client.chat_stream(...)` block.
        with self._client_wrapper.httpx_client.stream(
            "v1/chat",
            method="POST",
            json={
                "message": message,
                "model": model,
                "preamble": preamble,
                "chat_history": convert_and_respect_annotation_metadata(
                    object_=chat_history, annotation=typing.Sequence[Message], direction="write"
                ),
                "conversation_id": conversation_id,
                "prompt_truncation": prompt_truncation,
                "connectors": convert_and_respect_annotation_metadata(
                    object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write"
                ),
                "search_queries_only": search_queries_only,
                "documents": documents,
                "citation_quality": citation_quality,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "max_input_tokens": max_input_tokens,
                "k": k,
                "p": p,
                "seed": seed,
                "stop_sequences": stop_sequences,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "raw_prompting": raw_prompting,
                "tools": convert_and_respect_annotation_metadata(
                    object_=tools, annotation=typing.Sequence[Tool], direction="write"
                ),
                "tool_results": convert_and_respect_annotation_metadata(
                    object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write"
                ),
                "force_single_step": force_single_step,
                "response_format": convert_and_respect_annotation_metadata(
                    object_=response_format, annotation=ResponseFormat, direction="write"
                ),
                "safety_mode": safety_mode,
                "stream": True,
            },
            headers={
                "content-type": "application/json",
                "Accepts": str(accepts) if accepts is not None else None,
            },
            request_options=request_options,
            omit=OMIT,
        ) as _response:

            def _stream() -> HttpResponse[typing.Iterator[StreamedChatResponse]]:
                try:
                    if 200 <= _response.status_code < 300:
                        # Success: lazily decode each non-empty line as a StreamedChatResponse.
                        def _iter():
                            for _text in _response.iter_lines():
                                try:
                                    if len(_text) == 0:
                                        continue
                                    yield typing.cast(
                                        StreamedChatResponse,
                                        construct_type(
                                            type_=StreamedChatResponse,  # type: ignore
                                            object_=json.loads(_text),
                                        ),
                                    )
                                except Exception:
                                    # Undecodable events are skipped rather than aborting the stream.
                                    pass
                            return

                        return HttpResponse(response=_response, data=_iter())
                    # Error path: read the full body so it can be parsed into a typed error.
                    _response.read()
                    if _response.status_code == 400:
                        raise BadRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 401:
                        raise UnauthorizedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 403:
                        raise ForbiddenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 404:
                        raise NotFoundError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 422:
                        raise UnprocessableEntityError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 429:
                        raise TooManyRequestsError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 498:
                        raise InvalidTokenError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 499:
                        raise ClientClosedRequestError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 500:
                        raise InternalServerError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 501:
                        raise NotImplementedError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 503:
                        raise ServiceUnavailableError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    if _response.status_code == 504:
                        raise GatewayTimeoutError(
                            headers=dict(_response.headers),
                            body=typing.cast(
                                typing.Any,
                                construct_type(
                                    type_=typing.Any,  # type: ignore
                                    object_=_response.json(),
                                ),
                            ),
                        )
                    # Unrecognized status: fall through to a generic ApiError below.
                    _response_json = _response.json()
                except JSONDecodeError:
                    # Body was not JSON — surface the raw text instead.
                    raise ApiError(
                        status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
                    )
                except ValidationError as e:
                    raise ParsingError(
                        status_code=_response.status_code,
                        headers=dict(_response.headers),
                        body=_response.json(),
                        cause=e,
                    )
                raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

            yield _stream()
def chat(
self,
*,
message: str,
accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
model: typing.Optional[str] = OMIT,
preamble: typing.Optional[str] = OMIT,
chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
conversation_id: typing.Optional[str] = OMIT,
prompt_truncation: typing.Optional[ChatRequestPromptTruncation] = OMIT,
connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
search_queries_only: typing.Optional[bool] = OMIT,
documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT,
temperature: typing.Optional[float] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
max_input_tokens: typing.Optional[int] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
raw_prompting: typing.Optional[bool] = OMIT,
tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
force_single_step: typing.Optional[bool] = OMIT,
response_format: typing.Optional[ResponseFormat] = OMIT,
safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[NonStreamedChatResponse]:
"""
Generates a text response to a user message.
To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
Parameters
----------
message : str
Text input for the model to respond to.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
accepts : typing.Optional[typing.Literal["text/event-stream"]]
Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.
model : typing.Optional[str]
The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
Compatible Deployments: Cohere Platform, Private Deployments
preamble : typing.Optional[str]
When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
chat_history : typing.Optional[typing.Sequence[Message]]
A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
conversation_id : typing.Optional[str]
An alternative to `chat_history`.
Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
Compatible Deployments: Cohere Platform
prompt_truncation : typing.Optional[ChatRequestPromptTruncation]
Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
Dictates how the prompt will be constructed.
With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
Compatible Deployments:
- AUTO: Cohere Platform Only
- AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
connectors : typing.Optional[typing.Sequence[ChatConnector]]
Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
Compatible Deployments: Cohere Platform
search_queries_only : typing.Optional[bool]
Defaults to `false`.
When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
documents : typing.Optional[typing.Sequence[ChatDocument]]
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
Example:
```
[
{ "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
{ "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
]
```
Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
citation_quality : typing.Optional[ChatRequestCitationQuality]
Defaults to `"enabled"`.
Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
temperature : typing.Optional[float]
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_tokens : typing.Optional[int]
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_input_tokens : typing.Optional[int]
The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
Input will be truncated according to the `prompt_truncation` parameter.
Compatible Deployments: Cohere Platform
k : typing.Optional[int]
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
stop_sequences : typing.Optional[typing.Sequence[str]]
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
frequency_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
raw_prompting : typing.Optional[bool]
When enabled, the user's prompt will be sent to the model without
any pre-processing.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tools : typing.Optional[typing.Sequence[Tool]]
A list of available tools (functions) that the model may suggest invoking before producing a text response.
When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tool_results : typing.Optional[typing.Sequence[ToolResult]]
A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
**Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
```
tool_results = [
{
"call": {
"name": ,
"parameters": {
:
}
},
"outputs": [{
:
}]
},
...
]
```
**Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
force_single_step : typing.Optional[bool]
Forces the chat to be single step. Defaults to `false`.
response_format : typing.Optional[ResponseFormat]
safety_mode : typing.Optional[ChatRequestSafetyMode]
Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `NONE` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[NonStreamedChatResponse]
"""
_response = self._client_wrapper.httpx_client.request(
"v1/chat",
method="POST",
json={
"message": message,
"model": model,
"preamble": preamble,
"chat_history": convert_and_respect_annotation_metadata(
object_=chat_history, annotation=typing.Sequence[Message], direction="write"
),
"conversation_id": conversation_id,
"prompt_truncation": prompt_truncation,
"connectors": convert_and_respect_annotation_metadata(
object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write"
),
"search_queries_only": search_queries_only,
"documents": documents,
"citation_quality": citation_quality,
"temperature": temperature,
"max_tokens": max_tokens,
"max_input_tokens": max_input_tokens,
"k": k,
"p": p,
"seed": seed,
"stop_sequences": stop_sequences,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"raw_prompting": raw_prompting,
"tools": convert_and_respect_annotation_metadata(
object_=tools, annotation=typing.Sequence[Tool], direction="write"
),
"tool_results": convert_and_respect_annotation_metadata(
object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write"
),
"force_single_step": force_single_step,
"response_format": convert_and_respect_annotation_metadata(
object_=response_format, annotation=ResponseFormat, direction="write"
),
"safety_mode": safety_mode,
"stream": False,
},
headers={
"content-type": "application/json",
"Accepts": str(accepts) if accepts is not None else None,
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
NonStreamedChatResponse,
construct_type(
type_=NonStreamedChatResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
@contextlib.contextmanager
def generate_stream(
self,
*,
prompt: str,
model: typing.Optional[str] = OMIT,
num_generations: typing.Optional[int] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
truncate: typing.Optional[GenerateStreamRequestTruncate] = OMIT,
temperature: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
preset: typing.Optional[str] = OMIT,
end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT,
raw_prompting: typing.Optional[bool] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[HttpResponse[typing.Iterator[GenerateStreamedResponse]]]:
"""
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.
Generates realistic text conditioned on a given input.
Parameters
----------
prompt : str
The input text that serves as the starting point for generating the response.
Note: The prompt will be pre-processed and modified before reaching the model.
model : typing.Optional[str]
The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
num_generations : typing.Optional[int]
The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.
max_tokens : typing.Optional[int]
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.
truncate : typing.Optional[GenerateStreamRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
temperature : typing.Optional[float]
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
preset : typing.Optional[str]
Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.
end_sequences : typing.Optional[typing.Sequence[str]]
The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
stop_sequences : typing.Optional[typing.Sequence[str]]
The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.
k : typing.Optional[int]
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
frequency_penalty : typing.Optional[float]
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods]
One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
WARNING: `ALL` is deprecated, and will be removed in a future release.
raw_prompting : typing.Optional[bool]
When enabled, the user's prompt will be sent to the model without any pre-processing.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Yields
------
typing.Iterator[HttpResponse[typing.Iterator[GenerateStreamedResponse]]]
"""
with self._client_wrapper.httpx_client.stream(
"v1/generate",
method="POST",
json={
"prompt": prompt,
"model": model,
"num_generations": num_generations,
"max_tokens": max_tokens,
"truncate": truncate,
"temperature": temperature,
"seed": seed,
"preset": preset,
"end_sequences": end_sequences,
"stop_sequences": stop_sequences,
"k": k,
"p": p,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"return_likelihoods": return_likelihoods,
"raw_prompting": raw_prompting,
"stream": True,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
) as _response:
def _stream() -> HttpResponse[typing.Iterator[GenerateStreamedResponse]]:
try:
if 200 <= _response.status_code < 300:
def _iter():
for _text in _response.iter_lines():
try:
if len(_text) == 0:
continue
yield typing.cast(
GenerateStreamedResponse,
construct_type(
type_=GenerateStreamedResponse, # type: ignore
object_=json.loads(_text),
),
)
except Exception:
pass
return
return HttpResponse(response=_response, data=_iter())
_response.read()
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code,
headers=dict(_response.headers),
body=_response.json(),
cause=e,
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
yield _stream()
def generate(
self,
*,
prompt: str,
model: typing.Optional[str] = OMIT,
num_generations: typing.Optional[int] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
truncate: typing.Optional[GenerateRequestTruncate] = OMIT,
temperature: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
preset: typing.Optional[str] = OMIT,
end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT,
raw_prompting: typing.Optional[bool] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[Generation]:
"""
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.
Generates realistic text conditioned on a given input.
Parameters
----------
prompt : str
The input text that serves as the starting point for generating the response.
Note: The prompt will be pre-processed and modified before reaching the model.
model : typing.Optional[str]
The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
num_generations : typing.Optional[int]
The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.
max_tokens : typing.Optional[int]
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.
truncate : typing.Optional[GenerateRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
temperature : typing.Optional[float]
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
preset : typing.Optional[str]
Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.
end_sequences : typing.Optional[typing.Sequence[str]]
The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
stop_sequences : typing.Optional[typing.Sequence[str]]
The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.
k : typing.Optional[int]
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
frequency_penalty : typing.Optional[float]
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods]
One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
WARNING: `ALL` is deprecated, and will be removed in a future release.
raw_prompting : typing.Optional[bool]
When enabled, the user's prompt will be sent to the model without any pre-processing.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[Generation]
"""
_response = self._client_wrapper.httpx_client.request(
"v1/generate",
method="POST",
json={
"prompt": prompt,
"model": model,
"num_generations": num_generations,
"max_tokens": max_tokens,
"truncate": truncate,
"temperature": temperature,
"seed": seed,
"preset": preset,
"end_sequences": end_sequences,
"stop_sequences": stop_sequences,
"k": k,
"p": p,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"return_likelihoods": return_likelihoods,
"raw_prompting": raw_prompting,
"stream": False,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
Generation,
construct_type(
type_=Generation, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def embed(
self,
*,
texts: typing.Optional[typing.Sequence[str]] = OMIT,
images: typing.Optional[typing.Sequence[str]] = OMIT,
model: typing.Optional[str] = OMIT,
input_type: typing.Optional[EmbedInputType] = OMIT,
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[EmbedResponse]:
"""
This endpoint returns text and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents.
Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.
If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).
Parameters
----------
texts : typing.Optional[typing.Sequence[str]]
An array of strings for the model to embed. Maximum number of texts per call is `96`.
images : typing.Optional[typing.Sequence[str]]
An array of image data URIs for the model to embed. Maximum number of images per call is `1`.
The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.
Images are only supported with Embed v3.0 and newer models.
model : typing.Optional[str]
ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
input_type : typing.Optional[EmbedInputType]
embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
* `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
truncate : typing.Optional[EmbedRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[EmbedResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v1/embed",
method="POST",
json={
"texts": texts,
"images": images,
"model": model,
"input_type": input_type,
"embedding_types": embedding_types,
"truncate": truncate,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
EmbedResponse,
construct_type(
type_=EmbedResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def rerank(
self,
*,
query: str,
documents: typing.Sequence[RerankRequestDocumentsItem],
model: typing.Optional[str] = OMIT,
top_n: typing.Optional[int] = OMIT,
rank_fields: typing.Optional[typing.Sequence[str]] = OMIT,
return_documents: typing.Optional[bool] = OMIT,
max_chunks_per_doc: typing.Optional[int] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[RerankResponse]:
"""
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
Parameters
----------
query : str
The search query
documents : typing.Sequence[RerankRequestDocumentsItem]
A list of document objects or strings to rerank.
If a document is provided the text fields is required and all other fields will be preserved in the response.
The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000.
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
The identifier of the model to use, eg `rerank-v3.5`.
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
rank_fields : typing.Optional[typing.Sequence[str]]
If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially. If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking.
return_documents : typing.Optional[bool]
- If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request.
- If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.
max_chunks_per_doc : typing.Optional[int]
The maximum number of chunks to produce internally from a document
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[RerankResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v1/rerank",
method="POST",
json={
"model": model,
"query": query,
"documents": convert_and_respect_annotation_metadata(
object_=documents, annotation=typing.Sequence[RerankRequestDocumentsItem], direction="write"
),
"top_n": top_n,
"rank_fields": rank_fields,
"return_documents": return_documents,
"max_chunks_per_doc": max_chunks_per_doc,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
RerankResponse,
construct_type(
type_=RerankResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def classify(
self,
*,
inputs: typing.Sequence[str],
examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT,
model: typing.Optional[str] = OMIT,
preset: typing.Optional[str] = OMIT,
truncate: typing.Optional[ClassifyRequestTruncate] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[ClassifyResponse]:
"""
This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference.
Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
Parameters
----------
inputs : typing.Sequence[str]
A list of up to 96 texts to be classified. Each one must be a non-empty string.
There is, however, no consistent, universal limit to the length a particular input can be. We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models).
Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts.
examples : typing.Optional[typing.Sequence[ClassifyExample]]
An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`.
Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
model : typing.Optional[str]
ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model
preset : typing.Optional[str]
The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground). If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters.
truncate : typing.Optional[ClassifyRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[ClassifyResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v1/classify",
method="POST",
json={
"inputs": inputs,
"examples": convert_and_respect_annotation_metadata(
object_=examples, annotation=typing.Sequence[ClassifyExample], direction="write"
),
"model": model,
"preset": preset,
"truncate": truncate,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
ClassifyResponse,
construct_type(
type_=ClassifyResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def summarize(
self,
*,
text: str,
length: typing.Optional[SummarizeRequestLength] = OMIT,
format: typing.Optional[SummarizeRequestFormat] = OMIT,
model: typing.Optional[str] = OMIT,
extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT,
temperature: typing.Optional[float] = OMIT,
additional_command: typing.Optional[str] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[SummarizeResponse]:
"""
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.
Generates a summary in English for a given text.
Parameters
----------
text : str
The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English.
length : typing.Optional[SummarizeRequestLength]
One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text.
format : typing.Optional[SummarizeRequestFormat]
One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text.
model : typing.Optional[str]
The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better.
extractiveness : typing.Optional[SummarizeRequestExtractiveness]
One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text.
temperature : typing.Optional[float]
Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
additional_command : typing.Optional[str]
A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda"
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[SummarizeResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v1/summarize",
method="POST",
json={
"text": text,
"length": length,
"format": format,
"model": model,
"extractiveness": extractiveness,
"temperature": temperature,
"additional_command": additional_command,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
SummarizeResponse,
construct_type(
type_=SummarizeResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def tokenize(
self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None
) -> HttpResponse[TokenizeResponse]:
"""
This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page.
Parameters
----------
text : str
The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters.
model : str
The input will be tokenized by the tokenizer that is used by this model.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[TokenizeResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v1/tokenize",
method="POST",
json={
"text": text,
"model": model,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
TokenizeResponse,
construct_type(
type_=TokenizeResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def detokenize(
self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None
) -> HttpResponse[DetokenizeResponse]:
"""
This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page.
Parameters
----------
tokens : typing.Sequence[int]
The list of tokens to be detokenized.
model : str
An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[DetokenizeResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v1/detokenize",
method="POST",
json={
"tokens": tokens,
"model": model,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
DetokenizeResponse,
construct_type(
type_=DetokenizeResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def check_api_key(
self, *, request_options: typing.Optional[RequestOptions] = None
) -> HttpResponse[CheckApiKeyResponse]:
"""
Checks that the api key in the Authorization header is valid and active
Parameters
----------
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[CheckApiKeyResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v1/check-api-key",
method="POST",
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
CheckApiKeyResponse,
construct_type(
type_=CheckApiKeyResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
class AsyncRawBaseCohere:
    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        """Initialize the async raw client with its configured client wrapper."""
        # The wrapper supplies the httpx client, auth headers, and base URL
        # used by every request method on this class.
        self._client_wrapper = client_wrapper
@contextlib.asynccontextmanager
async def chat_stream(
self,
*,
message: str,
accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
model: typing.Optional[str] = OMIT,
preamble: typing.Optional[str] = OMIT,
chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
conversation_id: typing.Optional[str] = OMIT,
prompt_truncation: typing.Optional[ChatStreamRequestPromptTruncation] = OMIT,
connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
search_queries_only: typing.Optional[bool] = OMIT,
documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
citation_quality: typing.Optional[ChatStreamRequestCitationQuality] = OMIT,
temperature: typing.Optional[float] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
max_input_tokens: typing.Optional[int] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
raw_prompting: typing.Optional[bool] = OMIT,
tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
force_single_step: typing.Optional[bool] = OMIT,
response_format: typing.Optional[ResponseFormat] = OMIT,
safety_mode: typing.Optional[ChatStreamRequestSafetyMode] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[StreamedChatResponse]]]:
"""
Generates a streamed text response to a user message.
To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).
Parameters
----------
message : str
Text input for the model to respond to.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
accepts : typing.Optional[typing.Literal["text/event-stream"]]
Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.
model : typing.Optional[str]
The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
Compatible Deployments: Cohere Platform, Private Deployments
preamble : typing.Optional[str]
When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
chat_history : typing.Optional[typing.Sequence[Message]]
A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
conversation_id : typing.Optional[str]
An alternative to `chat_history`.
Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
Compatible Deployments: Cohere Platform
prompt_truncation : typing.Optional[ChatStreamRequestPromptTruncation]
Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
Dictates how the prompt will be constructed.
With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
Compatible Deployments:
- AUTO: Cohere Platform Only
- AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments
connectors : typing.Optional[typing.Sequence[ChatConnector]]
Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
Compatible Deployments: Cohere Platform
search_queries_only : typing.Optional[bool]
Defaults to `false`.
When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
documents : typing.Optional[typing.Sequence[ChatDocument]]
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
Example:
```
[
{ "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
{ "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
]
```
Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
citation_quality : typing.Optional[ChatStreamRequestCitationQuality]
Defaults to `"enabled"`.
Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
temperature : typing.Optional[float]
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_tokens : typing.Optional[int]
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
max_input_tokens : typing.Optional[int]
The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
Input will be truncated according to the `prompt_truncation` parameter.
Compatible Deployments: Cohere Platform
k : typing.Optional[int]
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
stop_sequences : typing.Optional[typing.Sequence[str]]
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
frequency_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
raw_prompting : typing.Optional[bool]
When enabled, the user's prompt will be sent to the model without
any pre-processing.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tools : typing.Optional[typing.Sequence[Tool]]
A list of available tools (functions) that the model may suggest invoking before producing a text response.
When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
tool_results : typing.Optional[typing.Sequence[ToolResult]]
A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
**Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
```
tool_results = [
{
"call": {
"name": ,
"parameters": {
:
}
},
"outputs": [{
:
}]
},
...
]
```
**Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
force_single_step : typing.Optional[bool]
Forces the chat to be single step. Defaults to `false`.
response_format : typing.Optional[ResponseFormat]
safety_mode : typing.Optional[ChatStreamRequestSafetyMode]
Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `NONE` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Yields
------
typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[StreamedChatResponse]]]
"""
async with self._client_wrapper.httpx_client.stream(
"v1/chat",
method="POST",
json={
"message": message,
"model": model,
"preamble": preamble,
"chat_history": convert_and_respect_annotation_metadata(
object_=chat_history, annotation=typing.Sequence[Message], direction="write"
),
"conversation_id": conversation_id,
"prompt_truncation": prompt_truncation,
"connectors": convert_and_respect_annotation_metadata(
object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write"
),
"search_queries_only": search_queries_only,
"documents": documents,
"citation_quality": citation_quality,
"temperature": temperature,
"max_tokens": max_tokens,
"max_input_tokens": max_input_tokens,
"k": k,
"p": p,
"seed": seed,
"stop_sequences": stop_sequences,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"raw_prompting": raw_prompting,
"tools": convert_and_respect_annotation_metadata(
object_=tools, annotation=typing.Sequence[Tool], direction="write"
),
"tool_results": convert_and_respect_annotation_metadata(
object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write"
),
"force_single_step": force_single_step,
"response_format": convert_and_respect_annotation_metadata(
object_=response_format, annotation=ResponseFormat, direction="write"
),
"safety_mode": safety_mode,
"stream": True,
},
headers={
"content-type": "application/json",
"Accepts": str(accepts) if accepts is not None else None,
},
request_options=request_options,
omit=OMIT,
) as _response:
async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[StreamedChatResponse]]:
try:
if 200 <= _response.status_code < 300:
async def _iter():
async for _text in _response.aiter_lines():
try:
if len(_text) == 0:
continue
yield typing.cast(
StreamedChatResponse,
construct_type(
type_=StreamedChatResponse, # type: ignore
object_=json.loads(_text),
),
)
except Exception:
pass
return
return AsyncHttpResponse(response=_response, data=_iter())
await _response.aread()
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code,
headers=dict(_response.headers),
body=_response.json(),
cause=e,
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
yield await _stream()
async def chat(
    self,
    *,
    message: str,
    accepts: typing.Optional[typing.Literal["text/event-stream"]] = None,
    model: typing.Optional[str] = OMIT,
    preamble: typing.Optional[str] = OMIT,
    chat_history: typing.Optional[typing.Sequence[Message]] = OMIT,
    conversation_id: typing.Optional[str] = OMIT,
    prompt_truncation: typing.Optional[ChatRequestPromptTruncation] = OMIT,
    connectors: typing.Optional[typing.Sequence[ChatConnector]] = OMIT,
    search_queries_only: typing.Optional[bool] = OMIT,
    documents: typing.Optional[typing.Sequence[ChatDocument]] = OMIT,
    citation_quality: typing.Optional[ChatRequestCitationQuality] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    max_input_tokens: typing.Optional[int] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    tools: typing.Optional[typing.Sequence[Tool]] = OMIT,
    tool_results: typing.Optional[typing.Sequence[ToolResult]] = OMIT,
    force_single_step: typing.Optional[bool] = OMIT,
    response_format: typing.Optional[ResponseFormat] = OMIT,
    safety_mode: typing.Optional[ChatRequestSafetyMode] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[NonStreamedChatResponse]:
    """
    Generates a text response to a user message.
    To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api).

    Parameters
    ----------
    message : str
        Text input for the model to respond to.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    accepts : typing.Optional[typing.Literal["text/event-stream"]]
        Pass text/event-stream to receive the streamed response as server-sent events. The default is `\\n` delimited events.

    model : typing.Optional[str]
        The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
        Compatible Deployments: Cohere Platform, Private Deployments

    preamble : typing.Optional[str]
        When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role.
        The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    chat_history : typing.Optional[typing.Sequence[Message]]
        A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`.
        Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
        The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    conversation_id : typing.Optional[str]
        An alternative to `chat_history`.
        Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string.
        Compatible Deployments: Cohere Platform

    prompt_truncation : typing.Optional[ChatRequestPromptTruncation]
        Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
        Dictates how the prompt will be constructed.
        With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance.
        With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API.
        With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned.
        Compatible Deployments:
        - AUTO: Cohere Platform Only
        - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments

    connectors : typing.Optional[typing.Sequence[ChatConnector]]
        Accepts `{"id": "web-search"}`, and/or the `"id"` for a custom [connector](https://docs.cohere.com/docs/connectors), if you've [created](https://docs.cohere.com/v1/docs/creating-and-deploying-a-connector) one.
        When specified, the model's reply will be enriched with information found by querying each of the connectors (RAG).
        Compatible Deployments: Cohere Platform

    search_queries_only : typing.Optional[bool]
        Defaults to `false`.
        When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    documents : typing.Optional[typing.Sequence[ChatDocument]]
        A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary.
        Example:
        ```
        [
          { "title": "Tall penguins", "text": "Emperor penguins are the tallest." },
          { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." },
        ]
        ```
        Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents.
        Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words.
        An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model.
        An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model.
        See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    citation_quality : typing.Optional[ChatRequestCitationQuality]
        Defaults to `"enabled"`.
        Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    temperature : typing.Optional[float]
        Defaults to `0.3`.
        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
        Randomness can be further maximized by increasing the value of the `p` parameter.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    max_tokens : typing.Optional[int]
        The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    max_input_tokens : typing.Optional[int]
        The maximum number of input tokens to send to the model. If not specified, `max_input_tokens` is the model's context length limit minus a small buffer.
        Input will be truncated according to the `prompt_truncation` parameter.
        Compatible Deployments: Cohere Platform

    k : typing.Optional[int]
        Ensures only the top `k` most likely tokens are considered for generation at each step.
        Defaults to `0`, min value of `0`, max value of `500`.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    stop_sequences : typing.Optional[typing.Sequence[str]]
        A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    frequency_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    raw_prompting : typing.Optional[bool]
        When enabled, the user's prompt will be sent to the model without
        any pre-processing.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    tools : typing.Optional[typing.Sequence[Tool]]
        A list of available tools (functions) that the model may suggest invoking before producing a text response.
        When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    tool_results : typing.Optional[typing.Sequence[ToolResult]]
        A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well.
        Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries.
        **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list.
        ```
        tool_results = [
          {
            "call": {
              "name": <tool name>,
              "parameters": {
                <param name>: <param value>
              }
            },
            "outputs": [{
              <key>: <value>
            }]
          },
          ...
        ]
        ```
        **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    force_single_step : typing.Optional[bool]
        Forces the chat to be single step. Defaults to `false`.

    response_format : typing.Optional[ResponseFormat]

    safety_mode : typing.Optional[ChatRequestSafetyMode]
        Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
        When `NONE` is specified, the safety instruction will be omitted.
        Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.
        **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
        **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[NonStreamedChatResponse]

    Raises
    ------
    BadRequestError, UnauthorizedError, ForbiddenError, NotFoundError,
    UnprocessableEntityError, TooManyRequestsError, InvalidTokenError,
    ClientClosedRequestError, InternalServerError, NotImplementedError,
    ServiceUnavailableError, GatewayTimeoutError
        Typed errors for the corresponding HTTP status codes.
    ApiError
        For any other non-2xx status, or a 2xx/error body that is not valid JSON.
    ParsingError
        When the response body fails model validation.
    """
    # Fields left at the OMIT sentinel are stripped from the JSON payload by the
    # HTTP client (omit=OMIT); sequence/model parameters are serialized through
    # their annotation metadata so field aliases are respected on the wire.
    _response = await self._client_wrapper.httpx_client.request(
        "v1/chat",
        method="POST",
        json={
            "message": message,
            "model": model,
            "preamble": preamble,
            "chat_history": convert_and_respect_annotation_metadata(
                object_=chat_history, annotation=typing.Sequence[Message], direction="write"
            ),
            "conversation_id": conversation_id,
            "prompt_truncation": prompt_truncation,
            "connectors": convert_and_respect_annotation_metadata(
                object_=connectors, annotation=typing.Sequence[ChatConnector], direction="write"
            ),
            "search_queries_only": search_queries_only,
            "documents": documents,
            "citation_quality": citation_quality,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "max_input_tokens": max_input_tokens,
            "k": k,
            "p": p,
            "seed": seed,
            "stop_sequences": stop_sequences,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "raw_prompting": raw_prompting,
            "tools": convert_and_respect_annotation_metadata(
                object_=tools, annotation=typing.Sequence[Tool], direction="write"
            ),
            "tool_results": convert_and_respect_annotation_metadata(
                object_=tool_results, annotation=typing.Sequence[ToolResult], direction="write"
            ),
            "force_single_step": force_single_step,
            "response_format": convert_and_respect_annotation_metadata(
                object_=response_format, annotation=ResponseFormat, direction="write"
            ),
            "safety_mode": safety_mode,
            "stream": False,
        },
        headers={
            "content-type": "application/json",
            # NOTE(review): header is spelled "Accepts" (not the standard
            # "Accept") — presumably intentional per the API spec; confirm.
            "Accepts": str(accepts) if accepts is not None else None,
        },
        request_options=request_options,
        omit=OMIT,
    )
    try:
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                NonStreamedChatResponse,
                construct_type(
                    type_=NonStreamedChatResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        # Map the known error status codes onto their typed SDK exceptions.
        # Every one of them is raised with the same shape: response headers
        # plus the JSON body coerced through construct_type.
        _error_class = {
            400: BadRequestError,
            401: UnauthorizedError,
            403: ForbiddenError,
            404: NotFoundError,
            422: UnprocessableEntityError,
            429: TooManyRequestsError,
            498: InvalidTokenError,
            499: ClientClosedRequestError,
            500: InternalServerError,
            501: NotImplementedError,  # SDK-defined error type, not the builtin
            503: ServiceUnavailableError,
            504: GatewayTimeoutError,
        }.get(_response.status_code)
        if _error_class is not None:
            raise _error_class(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        # Unrecognized status code: parse the body (may raise JSONDecodeError)
        # and fall through to the generic ApiError below.
        _response_json = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
@contextlib.asynccontextmanager
async def generate_stream(
    self,
    *,
    prompt: str,
    model: typing.Optional[str] = OMIT,
    num_generations: typing.Optional[int] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    truncate: typing.Optional[GenerateStreamRequestTruncate] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    preset: typing.Optional[str] = OMIT,
    end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    return_likelihoods: typing.Optional[GenerateStreamRequestReturnLikelihoods] = OMIT,
    raw_prompting: typing.Optional[bool] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[GenerateStreamedResponse]]]:
    """
    This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat with Streaming API.
    Generates realistic text conditioned on a given input.

    Parameters
    ----------
    prompt : str
        The input text that serves as the starting point for generating the response.
        Note: The prompt will be pre-processed and modified before reaching the model.

    model : typing.Optional[str]
        The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
        Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

    num_generations : typing.Optional[int]
        The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.

    max_tokens : typing.Optional[int]
        The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
        This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
        Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.

    truncate : typing.Optional[GenerateStreamRequestTruncate]
        One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
        Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
        If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

    temperature : typing.Optional[float]
        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
        Defaults to `0.75`, min value of `0.0`, max value of `5.0`.

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.
        Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

    preset : typing.Optional[str]
        Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
        When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.

    end_sequences : typing.Optional[typing.Sequence[str]]
        The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.

    k : typing.Optional[int]
        Ensures only the top `k` most likely tokens are considered for generation at each step.
        Defaults to `0`, min value of `0`, max value of `500`.

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

    frequency_penalty : typing.Optional[float]
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
        Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
        Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.

    return_likelihoods : typing.Optional[GenerateStreamRequestReturnLikelihoods]
        One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
        If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
        WARNING: `ALL` is deprecated, and will be removed in a future release.

    raw_prompting : typing.Optional[bool]
        When enabled, the user's prompt will be sent to the model without any pre-processing.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Yields
    ------
    typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[GenerateStreamedResponse]]]
    """
    # The streaming connection stays open for the lifetime of the context
    # manager; OMIT-valued fields are dropped from the JSON body (omit=OMIT).
    async with self._client_wrapper.httpx_client.stream(
        "v1/generate",
        method="POST",
        json={
            "prompt": prompt,
            "model": model,
            "num_generations": num_generations,
            "max_tokens": max_tokens,
            "truncate": truncate,
            "temperature": temperature,
            "seed": seed,
            "preset": preset,
            "end_sequences": end_sequences,
            "stop_sequences": stop_sequences,
            "k": k,
            "p": p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "return_likelihoods": return_likelihoods,
            "raw_prompting": raw_prompting,
            "stream": True,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    ) as _response:

        async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[GenerateStreamedResponse]]:
            try:
                if 200 <= _response.status_code < 300:
                    # Lazily decode each non-empty response line into a typed
                    # streamed event as the caller iterates.
                    async def _iter():
                        async for _text in _response.aiter_lines():
                            try:
                                if len(_text) == 0:
                                    continue
                                yield typing.cast(
                                    GenerateStreamedResponse,
                                    construct_type(
                                        type_=GenerateStreamedResponse,  # type: ignore
                                        object_=json.loads(_text),
                                    ),
                                )
                            except Exception:
                                # Best-effort parsing: lines that fail to decode
                                # are skipped rather than aborting the stream.
                                pass
                        return

                    return AsyncHttpResponse(response=_response, data=_iter())
                # Error path: the streamed body must be fully read before it
                # can be inspected via .json()/.text below.
                await _response.aread()
                # Map the known error status codes onto their typed SDK
                # exceptions; each carries the response headers and the JSON
                # body coerced through construct_type.
                _error_class = {
                    400: BadRequestError,
                    401: UnauthorizedError,
                    403: ForbiddenError,
                    404: NotFoundError,
                    422: UnprocessableEntityError,
                    429: TooManyRequestsError,
                    498: InvalidTokenError,
                    499: ClientClosedRequestError,
                    500: InternalServerError,
                    501: NotImplementedError,  # SDK-defined error type, not the builtin
                    503: ServiceUnavailableError,
                    504: GatewayTimeoutError,
                }.get(_response.status_code)
                if _error_class is not None:
                    raise _error_class(
                        headers=dict(_response.headers),
                        body=typing.cast(
                            typing.Any,
                            construct_type(
                                type_=typing.Any,  # type: ignore
                                object_=_response.json(),
                            ),
                        ),
                    )
                # Unrecognized status code: parse the body (may raise
                # JSONDecodeError) and fall through to the generic ApiError.
                _response_json = _response.json()
            except JSONDecodeError:
                raise ApiError(
                    status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
                )
            except ValidationError as e:
                raise ParsingError(
                    status_code=_response.status_code,
                    headers=dict(_response.headers),
                    body=_response.json(),
                    cause=e,
                )
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

        yield await _stream()
async def generate(
self,
*,
prompt: str,
model: typing.Optional[str] = OMIT,
num_generations: typing.Optional[int] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
truncate: typing.Optional[GenerateRequestTruncate] = OMIT,
temperature: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
preset: typing.Optional[str] = OMIT,
end_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
return_likelihoods: typing.Optional[GenerateRequestReturnLikelihoods] = OMIT,
raw_prompting: typing.Optional[bool] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[Generation]:
"""
This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.
Generates realistic text conditioned on a given input.
Parameters
----------
prompt : str
The input text that serves as the starting point for generating the response.
Note: The prompt will be pre-processed and modified before reaching the model.
model : typing.Optional[str]
The identifier of the model to generate with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental).
Smaller, "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
num_generations : typing.Optional[int]
The maximum number of generations that will be returned. Defaults to `1`, min value of `1`, max value of `5`.
max_tokens : typing.Optional[int]
The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
This parameter is off by default, and if it's not specified, the model will continue generating until it emits an EOS completion token. See [BPE Tokens](/bpe-tokens-wiki) for more details.
Can only be set to `0` if `return_likelihoods` is set to `ALL` to get the likelihood of the prompt.
truncate : typing.Optional[GenerateRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
temperature : typing.Optional[float]
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments
preset : typing.Optional[str]
Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.com/playground/generate).
When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters.
end_sequences : typing.Optional[typing.Sequence[str]]
The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
stop_sequences : typing.Optional[typing.Sequence[str]]
The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.
k : typing.Optional[int]
Ensures only the top `k` most likely tokens are considered for generation at each step.
Defaults to `0`, min value of `0`, max value of `500`.
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
frequency_penalty : typing.Optional[float]
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models.
return_likelihoods : typing.Optional[GenerateRequestReturnLikelihoods]
One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`.
If `GENERATION` is selected, the token likelihoods will only be provided for generated text.
WARNING: `ALL` is deprecated, and will be removed in a future release.
raw_prompting : typing.Optional[bool]
When enabled, the user's prompt will be sent to the model without any pre-processing.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[Generation]
"""
_response = await self._client_wrapper.httpx_client.request(
"v1/generate",
method="POST",
json={
"prompt": prompt,
"model": model,
"num_generations": num_generations,
"max_tokens": max_tokens,
"truncate": truncate,
"temperature": temperature,
"seed": seed,
"preset": preset,
"end_sequences": end_sequences,
"stop_sequences": stop_sequences,
"k": k,
"p": p,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"return_likelihoods": return_likelihoods,
"raw_prompting": raw_prompting,
"stream": False,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
Generation,
construct_type(
type_=Generation, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def embed(
self,
*,
texts: typing.Optional[typing.Sequence[str]] = OMIT,
images: typing.Optional[typing.Sequence[str]] = OMIT,
model: typing.Optional[str] = OMIT,
input_type: typing.Optional[EmbedInputType] = OMIT,
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[EmbedResponse]:
"""
This endpoint returns text and image embeddings. An embedding is a list of floating point numbers that captures semantic information about the content that it represents.
Embeddings can be used to create classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.
If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).
Parameters
----------
texts : typing.Optional[typing.Sequence[str]]
An array of strings for the model to embed. Maximum number of texts per call is `96`.
images : typing.Optional[typing.Sequence[str]]
An array of image data URIs for the model to embed. Maximum number of images per call is `1`.
The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.
Images are only supported with Embed v3.0 and newer models.
model : typing.Optional[str]
ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
input_type : typing.Optional[EmbedInputType]
embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
* `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
* `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
truncate : typing.Optional[EmbedRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[EmbedResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
"v1/embed",
method="POST",
json={
"texts": texts,
"images": images,
"model": model,
"input_type": input_type,
"embedding_types": embedding_types,
"truncate": truncate,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
EmbedResponse,
construct_type(
type_=EmbedResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def rerank(
self,
*,
query: str,
documents: typing.Sequence[RerankRequestDocumentsItem],
model: typing.Optional[str] = OMIT,
top_n: typing.Optional[int] = OMIT,
rank_fields: typing.Optional[typing.Sequence[str]] = OMIT,
return_documents: typing.Optional[bool] = OMIT,
max_chunks_per_doc: typing.Optional[int] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[RerankResponse]:
"""
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
Parameters
----------
query : str
The search query
documents : typing.Sequence[RerankRequestDocumentsItem]
A list of document objects or strings to rerank.
If a document is provided the text fields is required and all other fields will be preserved in the response.
The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000.
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
The identifier of the model to use, eg `rerank-v3.5`.
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
rank_fields : typing.Optional[typing.Sequence[str]]
If a JSON object is provided, you can specify which keys you would like to have considered for reranking. The model will rerank based on order of the fields passed in (i.e. rank_fields=['title','author','text'] will rerank using the values in title, author, text sequentially. If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). If not provided, the model will use the default text field for ranking.
return_documents : typing.Optional[bool]
- If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request.
- If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.
max_chunks_per_doc : typing.Optional[int]
The maximum number of chunks to produce internally from a document
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[RerankResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
"v1/rerank",
method="POST",
json={
"model": model,
"query": query,
"documents": convert_and_respect_annotation_metadata(
object_=documents, annotation=typing.Sequence[RerankRequestDocumentsItem], direction="write"
),
"top_n": top_n,
"rank_fields": rank_fields,
"return_documents": return_documents,
"max_chunks_per_doc": max_chunks_per_doc,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
RerankResponse,
construct_type(
type_=RerankResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def classify(
self,
*,
inputs: typing.Sequence[str],
examples: typing.Optional[typing.Sequence[ClassifyExample]] = OMIT,
model: typing.Optional[str] = OMIT,
preset: typing.Optional[str] = OMIT,
truncate: typing.Optional[ClassifyRequestTruncate] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[ClassifyResponse]:
"""
This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference.
Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
Parameters
----------
inputs : typing.Sequence[str]
A list of up to 96 texts to be classified. Each one must be a non-empty string.
There is, however, no consistent, universal limit to the length a particular input can be. We perform classification on the first `x` tokens of each input, and `x` varies depending on which underlying model is powering classification. The maximum token length for each model is listed in the "max tokens" column [here](https://docs.cohere.com/docs/models).
Note: by default the `truncate` parameter is set to `END`, so tokens exceeding the limit will be automatically dropped. This behavior can be disabled by setting `truncate` to `NONE`, which will result in validation errors for longer texts.
examples : typing.Optional[typing.Sequence[ClassifyExample]]
An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`.
Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly.
model : typing.Optional[str]
ID of a [Fine-tuned](https://docs.cohere.com/v2/docs/classify-starting-the-training) Classify model
preset : typing.Optional[str]
The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.com/playground). If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters.
truncate : typing.Optional[ClassifyRequestTruncate]
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[ClassifyResponse]
OK
"""
_response = await self._client_wrapper.httpx_client.request(
"v1/classify",
method="POST",
json={
"inputs": inputs,
"examples": convert_and_respect_annotation_metadata(
object_=examples, annotation=typing.Sequence[ClassifyExample], direction="write"
),
"model": model,
"preset": preset,
"truncate": truncate,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
ClassifyResponse,
construct_type(
type_=ClassifyResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def summarize(
    self,
    *,
    text: str,
    length: typing.Optional[SummarizeRequestLength] = OMIT,
    format: typing.Optional[SummarizeRequestFormat] = OMIT,
    model: typing.Optional[str] = OMIT,
    extractiveness: typing.Optional[SummarizeRequestExtractiveness] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    additional_command: typing.Optional[str] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[SummarizeResponse]:
    """
    This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API.
    Generates a summary in English for a given text.
    Parameters
    ----------
    text : str
        The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English.
    length : typing.Optional[SummarizeRequestLength]
        One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text.
    format : typing.Optional[SummarizeRequestFormat]
        One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text.
    model : typing.Optional[str]
        The identifier of the model to generate the summary with. Currently available models are `command` (default), `command-nightly` (experimental), `command-light`, and `command-light-nightly` (experimental). Smaller, "light" models are faster, while larger models will perform better.
    extractiveness : typing.Optional[SummarizeRequestExtractiveness]
        One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text.
    temperature : typing.Optional[float]
        Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
    additional_command : typing.Optional[str]
        A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda"
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.
    Returns
    -------
    AsyncHttpResponse[SummarizeResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/summarize",
        method="POST",
        json={
            "text": text,
            "length": length,
            "format": format,
            "model": model,
            "extractiveness": extractiveness,
            "temperature": temperature,
            "additional_command": additional_command,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    # Typed exception for every documented non-2xx status code.
    _status_to_error = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _parsed = typing.cast(
                SummarizeResponse,
                construct_type(
                    type_=SummarizeResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_parsed)
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Body was not valid JSON — surface the raw text instead.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unknown status code: fall back to the generic API error.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def tokenize(
    self, *, text: str, model: str, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[TokenizeResponse]:
    """
    This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page.
    Parameters
    ----------
    text : str
        The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters.
    model : str
        The input will be tokenized by the tokenizer that is used by this model.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.
    Returns
    -------
    AsyncHttpResponse[TokenizeResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/tokenize",
        method="POST",
        json={
            "text": text,
            "model": model,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    # Typed exception for every documented non-2xx status code.
    _status_to_error = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _parsed = typing.cast(
                TokenizeResponse,
                construct_type(
                    type_=TokenizeResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_parsed)
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Body was not valid JSON — surface the raw text instead.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unknown status code: fall back to the generic API error.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def detokenize(
    self, *, tokens: typing.Sequence[int], model: str, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[DetokenizeResponse]:
    """
    This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page.
    Parameters
    ----------
    tokens : typing.Sequence[int]
        The list of tokens to be detokenized.
    model : str
        An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.
    Returns
    -------
    AsyncHttpResponse[DetokenizeResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/detokenize",
        method="POST",
        json={
            "tokens": tokens,
            "model": model,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    # Typed exception for every documented non-2xx status code.
    _status_to_error = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _parsed = typing.cast(
                DetokenizeResponse,
                construct_type(
                    type_=DetokenizeResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_parsed)
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Body was not valid JSON — surface the raw text instead.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unknown status code: fall back to the generic API error.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def check_api_key(
    self, *, request_options: typing.Optional[RequestOptions] = None
) -> AsyncHttpResponse[CheckApiKeyResponse]:
    """
    Checks that the api key in the Authorization header is valid and active
    Parameters
    ----------
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.
    Returns
    -------
    AsyncHttpResponse[CheckApiKeyResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v1/check-api-key",
        method="POST",
        request_options=request_options,
    )
    # Typed exception for every documented non-2xx status code.
    _status_to_error = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _parsed = typing.cast(
                CheckApiKeyResponse,
                construct_type(
                    type_=CheckApiKeyResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_parsed)
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Body was not valid JSON — surface the raw text instead.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unknown status code: fall back to the generic API error.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/sagemaker_client.py
================================================
import typing
from .aws_client import AwsClient, AwsClientV2
from .manually_maintained.cohere_aws.client import Client
from .manually_maintained.cohere_aws.mode import Mode
class SagemakerClient(AwsClient):
    """V1 Cohere client that talks to models deployed on AWS Sagemaker."""

    # Optional fine-tuning helper; only set when construction succeeds below.
    sagemaker_finetuning: Client

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        super().__init__(
            service="sagemaker",
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
            timeout=timeout,
        )
        # Best effort: building the fine-tuning client requires a working AWS
        # environment; leave the attribute unset rather than fail construction.
        try:
            self.sagemaker_finetuning = Client(aws_region=aws_region)
        except Exception:
            pass
class SagemakerClientV2(AwsClientV2):
    """V2 Cohere client that talks to models deployed on AWS Sagemaker."""

    # Optional fine-tuning helper; only set when construction succeeds below.
    sagemaker_finetuning: Client

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        super().__init__(
            service="sagemaker",
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
            timeout=timeout,
        )
        # Best effort: building the fine-tuning client requires a working AWS
        # environment; leave the attribute unset rather than fail construction.
        try:
            self.sagemaker_finetuning = Client(aws_region=aws_region)
        except Exception:
            pass
================================================
FILE: src/cohere/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .api_meta import ApiMeta
from .api_meta_api_version import ApiMetaApiVersion
from .api_meta_billed_units import ApiMetaBilledUnits
from .api_meta_tokens import ApiMetaTokens
from .assistant_message import AssistantMessage
from .assistant_message_response import AssistantMessageResponse
from .assistant_message_response_content_item import (
AssistantMessageResponseContentItem,
TextAssistantMessageResponseContentItem,
ThinkingAssistantMessageResponseContentItem,
)
from .assistant_message_v2content import AssistantMessageV2Content
from .assistant_message_v2content_one_item import (
AssistantMessageV2ContentOneItem,
TextAssistantMessageV2ContentOneItem,
ThinkingAssistantMessageV2ContentOneItem,
)
from .auth_token_type import AuthTokenType
from .chat_citation import ChatCitation
from .chat_citation_generation_event import ChatCitationGenerationEvent
from .chat_citation_type import ChatCitationType
from .chat_connector import ChatConnector
from .chat_content_delta_event import ChatContentDeltaEvent
from .chat_content_delta_event_delta import ChatContentDeltaEventDelta
from .chat_content_delta_event_delta_message import ChatContentDeltaEventDeltaMessage
from .chat_content_delta_event_delta_message_content import ChatContentDeltaEventDeltaMessageContent
from .chat_content_end_event import ChatContentEndEvent
from .chat_content_start_event import ChatContentStartEvent
from .chat_content_start_event_delta import ChatContentStartEventDelta
from .chat_content_start_event_delta_message import ChatContentStartEventDeltaMessage
from .chat_content_start_event_delta_message_content import ChatContentStartEventDeltaMessageContent
from .chat_content_start_event_delta_message_content_type import ChatContentStartEventDeltaMessageContentType
from .chat_data_metrics import ChatDataMetrics
from .chat_debug_event import ChatDebugEvent
from .chat_document import ChatDocument
from .chat_document_source import ChatDocumentSource
from .chat_finish_reason import ChatFinishReason
from .chat_message import ChatMessage
from .chat_message_end_event import ChatMessageEndEvent
from .chat_message_end_event_delta import ChatMessageEndEventDelta
from .chat_message_start_event import ChatMessageStartEvent
from .chat_message_start_event_delta import ChatMessageStartEventDelta
from .chat_message_start_event_delta_message import ChatMessageStartEventDeltaMessage
from .chat_message_v2 import (
AssistantChatMessageV2,
ChatMessageV2,
SystemChatMessageV2,
ToolChatMessageV2,
UserChatMessageV2,
)
from .chat_messages import ChatMessages
from .chat_request_citation_quality import ChatRequestCitationQuality
from .chat_request_prompt_truncation import ChatRequestPromptTruncation
from .chat_request_safety_mode import ChatRequestSafetyMode
from .chat_search_queries_generation_event import ChatSearchQueriesGenerationEvent
from .chat_search_query import ChatSearchQuery
from .chat_search_result import ChatSearchResult
from .chat_search_result_connector import ChatSearchResultConnector
from .chat_search_results_event import ChatSearchResultsEvent
from .chat_stream_end_event import ChatStreamEndEvent
from .chat_stream_end_event_finish_reason import ChatStreamEndEventFinishReason
from .chat_stream_event import ChatStreamEvent
from .chat_stream_event_type import ChatStreamEventType
from .chat_stream_request_citation_quality import ChatStreamRequestCitationQuality
from .chat_stream_request_prompt_truncation import ChatStreamRequestPromptTruncation
from .chat_stream_request_safety_mode import ChatStreamRequestSafetyMode
from .chat_stream_start_event import ChatStreamStartEvent
from .chat_text_content import ChatTextContent
from .chat_text_generation_event import ChatTextGenerationEvent
from .chat_text_response_format import ChatTextResponseFormat
from .chat_text_response_format_v2 import ChatTextResponseFormatV2
from .chat_thinking_content import ChatThinkingContent
from .chat_tool_call_delta_event import ChatToolCallDeltaEvent
from .chat_tool_call_delta_event_delta import ChatToolCallDeltaEventDelta
from .chat_tool_call_delta_event_delta_message import ChatToolCallDeltaEventDeltaMessage
from .chat_tool_call_delta_event_delta_message_tool_calls import ChatToolCallDeltaEventDeltaMessageToolCalls
from .chat_tool_call_delta_event_delta_message_tool_calls_function import (
ChatToolCallDeltaEventDeltaMessageToolCallsFunction,
)
from .chat_tool_call_end_event import ChatToolCallEndEvent
from .chat_tool_call_start_event import ChatToolCallStartEvent
from .chat_tool_call_start_event_delta import ChatToolCallStartEventDelta
from .chat_tool_call_start_event_delta_message import ChatToolCallStartEventDeltaMessage
from .chat_tool_calls_chunk_event import ChatToolCallsChunkEvent
from .chat_tool_calls_generation_event import ChatToolCallsGenerationEvent
from .chat_tool_message import ChatToolMessage
from .chat_tool_plan_delta_event import ChatToolPlanDeltaEvent
from .chat_tool_plan_delta_event_delta import ChatToolPlanDeltaEventDelta
from .chat_tool_plan_delta_event_delta_message import ChatToolPlanDeltaEventDeltaMessage
from .chat_tool_source import ChatToolSource
from .check_api_key_response import CheckApiKeyResponse
from .citation import Citation
from .citation_end_event import CitationEndEvent
from .citation_options import CitationOptions
from .citation_options_mode import CitationOptionsMode
from .citation_start_event import CitationStartEvent
from .citation_start_event_delta import CitationStartEventDelta
from .citation_start_event_delta_message import CitationStartEventDeltaMessage
from .citation_type import CitationType
from .classify_data_metrics import ClassifyDataMetrics
from .classify_example import ClassifyExample
from .classify_request_truncate import ClassifyRequestTruncate
from .classify_response import ClassifyResponse
from .classify_response_classifications_item import ClassifyResponseClassificationsItem
from .classify_response_classifications_item_classification_type import (
ClassifyResponseClassificationsItemClassificationType,
)
from .classify_response_classifications_item_labels_value import ClassifyResponseClassificationsItemLabelsValue
from .compatible_endpoint import CompatibleEndpoint
from .connector import Connector
from .connector_auth_status import ConnectorAuthStatus
from .connector_o_auth import ConnectorOAuth
from .content import Content, ImageUrlContent, TextContent
from .create_connector_o_auth import CreateConnectorOAuth
from .create_connector_response import CreateConnectorResponse
from .create_connector_service_auth import CreateConnectorServiceAuth
from .create_embed_job_response import CreateEmbedJobResponse
from .dataset import Dataset
from .dataset_part import DatasetPart
from .dataset_type import DatasetType
from .dataset_validation_status import DatasetValidationStatus
from .delete_connector_response import DeleteConnectorResponse
from .detokenize_response import DetokenizeResponse
from .document import Document
from .document_content import DocumentContent
from .embed_by_type_response import EmbedByTypeResponse
from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings
from .embed_by_type_response_response_type import EmbedByTypeResponseResponseType
from .embed_content import EmbedContent, ImageUrlEmbedContent, TextEmbedContent
from .embed_floats_response import EmbedFloatsResponse
from .embed_image import EmbedImage
from .embed_image_url import EmbedImageUrl
from .embed_input import EmbedInput
from .embed_input_type import EmbedInputType
from .embed_job import EmbedJob
from .embed_job_status import EmbedJobStatus
from .embed_job_truncate import EmbedJobTruncate
from .embed_request_truncate import EmbedRequestTruncate
from .embed_response import EmbedResponse, EmbeddingsByTypeEmbedResponse, EmbeddingsFloatsEmbedResponse
from .embed_text import EmbedText
from .embedding_type import EmbeddingType
from .finetune_dataset_metrics import FinetuneDatasetMetrics
from .finish_reason import FinishReason
from .generate_request_return_likelihoods import GenerateRequestReturnLikelihoods
from .generate_request_truncate import GenerateRequestTruncate
from .generate_stream_end import GenerateStreamEnd
from .generate_stream_end_response import GenerateStreamEndResponse
from .generate_stream_error import GenerateStreamError
from .generate_stream_event import GenerateStreamEvent
from .generate_stream_request_return_likelihoods import GenerateStreamRequestReturnLikelihoods
from .generate_stream_request_truncate import GenerateStreamRequestTruncate
from .generate_stream_text import GenerateStreamText
from .generate_streamed_response import (
GenerateStreamedResponse,
StreamEndGenerateStreamedResponse,
StreamErrorGenerateStreamedResponse,
TextGenerationGenerateStreamedResponse,
)
from .generation import Generation
from .get_connector_response import GetConnectorResponse
from .get_model_response import GetModelResponse
from .get_model_response_sampling_defaults import GetModelResponseSamplingDefaults
from .image import Image
from .image_content import ImageContent
from .image_url import ImageUrl
from .image_url_detail import ImageUrlDetail
from .json_response_format import JsonResponseFormat
from .json_response_format_v2 import JsonResponseFormatV2
from .label_metric import LabelMetric
from .list_connectors_response import ListConnectorsResponse
from .list_embed_job_response import ListEmbedJobResponse
from .list_models_response import ListModelsResponse
from .logprob_item import LogprobItem
from .message import ChatbotMessage, Message, SystemMessage, ToolMessage, UserMessage
from .metrics import Metrics
from .non_streamed_chat_response import NonStreamedChatResponse
from .o_auth_authorize_response import OAuthAuthorizeResponse
from .parse_info import ParseInfo
from .rerank_document import RerankDocument
from .rerank_request_documents_item import RerankRequestDocumentsItem
from .rerank_response import RerankResponse
from .rerank_response_results_item import RerankResponseResultsItem
from .rerank_response_results_item_document import RerankResponseResultsItemDocument
from .reranker_data_metrics import RerankerDataMetrics
from .response_format import JsonObjectResponseFormat, ResponseFormat, TextResponseFormat
from .response_format_v2 import JsonObjectResponseFormatV2, ResponseFormatV2, TextResponseFormatV2
from .single_generation import SingleGeneration
from .single_generation_in_stream import SingleGenerationInStream
from .single_generation_token_likelihoods_item import SingleGenerationTokenLikelihoodsItem
from .source import DocumentSource, Source, ToolSource
from .streamed_chat_response import (
CitationGenerationStreamedChatResponse,
DebugStreamedChatResponse,
SearchQueriesGenerationStreamedChatResponse,
SearchResultsStreamedChatResponse,
StreamEndStreamedChatResponse,
StreamStartStreamedChatResponse,
StreamedChatResponse,
TextGenerationStreamedChatResponse,
ToolCallsChunkStreamedChatResponse,
ToolCallsGenerationStreamedChatResponse,
)
from .summarize_request_extractiveness import SummarizeRequestExtractiveness
from .summarize_request_format import SummarizeRequestFormat
from .summarize_request_length import SummarizeRequestLength
from .summarize_response import SummarizeResponse
from .system_message_v2 import SystemMessageV2
from .system_message_v2content import SystemMessageV2Content
from .system_message_v2content_one_item import SystemMessageV2ContentOneItem, TextSystemMessageV2ContentOneItem
from .thinking import Thinking
from .thinking_type import ThinkingType
from .tokenize_response import TokenizeResponse
from .tool import Tool
from .tool_call import ToolCall
from .tool_call_delta import ToolCallDelta
from .tool_call_v2 import ToolCallV2
from .tool_call_v2function import ToolCallV2Function
from .tool_content import DocumentToolContent, TextToolContent, ToolContent
from .tool_message_v2 import ToolMessageV2
from .tool_message_v2content import ToolMessageV2Content
from .tool_parameter_definitions_value import ToolParameterDefinitionsValue
from .tool_result import ToolResult
from .tool_v2 import ToolV2
from .tool_v2function import ToolV2Function
from .update_connector_response import UpdateConnectorResponse
from .usage import Usage
from .usage_billed_units import UsageBilledUnits
from .usage_tokens import UsageTokens
from .user_message_v2 import UserMessageV2
from .user_message_v2content import UserMessageV2Content
# Lazy-import table: maps each public attribute name exported by this module to
# the relative module path that defines it. Consumed by the module-level
# __getattr__ below (PEP 562) to import the defining module on first access.
# Note that several attributes deliberately map to the same module (e.g. the
# ".streamed_chat_response" and ".chat_message_v2" variants).
_dynamic_imports: typing.Dict[str, str] = {
    "ApiMeta": ".api_meta",
    "ApiMetaApiVersion": ".api_meta_api_version",
    "ApiMetaBilledUnits": ".api_meta_billed_units",
    "ApiMetaTokens": ".api_meta_tokens",
    "AssistantChatMessageV2": ".chat_message_v2",
    "AssistantMessage": ".assistant_message",
    "AssistantMessageResponse": ".assistant_message_response",
    "AssistantMessageResponseContentItem": ".assistant_message_response_content_item",
    "AssistantMessageV2Content": ".assistant_message_v2content",
    "AssistantMessageV2ContentOneItem": ".assistant_message_v2content_one_item",
    "AuthTokenType": ".auth_token_type",
    "ChatCitation": ".chat_citation",
    "ChatCitationGenerationEvent": ".chat_citation_generation_event",
    "ChatCitationType": ".chat_citation_type",
    "ChatConnector": ".chat_connector",
    "ChatContentDeltaEvent": ".chat_content_delta_event",
    "ChatContentDeltaEventDelta": ".chat_content_delta_event_delta",
    "ChatContentDeltaEventDeltaMessage": ".chat_content_delta_event_delta_message",
    "ChatContentDeltaEventDeltaMessageContent": ".chat_content_delta_event_delta_message_content",
    "ChatContentEndEvent": ".chat_content_end_event",
    "ChatContentStartEvent": ".chat_content_start_event",
    "ChatContentStartEventDelta": ".chat_content_start_event_delta",
    "ChatContentStartEventDeltaMessage": ".chat_content_start_event_delta_message",
    "ChatContentStartEventDeltaMessageContent": ".chat_content_start_event_delta_message_content",
    "ChatContentStartEventDeltaMessageContentType": ".chat_content_start_event_delta_message_content_type",
    "ChatDataMetrics": ".chat_data_metrics",
    "ChatDebugEvent": ".chat_debug_event",
    "ChatDocument": ".chat_document",
    "ChatDocumentSource": ".chat_document_source",
    "ChatFinishReason": ".chat_finish_reason",
    "ChatMessage": ".chat_message",
    "ChatMessageEndEvent": ".chat_message_end_event",
    "ChatMessageEndEventDelta": ".chat_message_end_event_delta",
    "ChatMessageStartEvent": ".chat_message_start_event",
    "ChatMessageStartEventDelta": ".chat_message_start_event_delta",
    "ChatMessageStartEventDeltaMessage": ".chat_message_start_event_delta_message",
    "ChatMessageV2": ".chat_message_v2",
    "ChatMessages": ".chat_messages",
    "ChatRequestCitationQuality": ".chat_request_citation_quality",
    "ChatRequestPromptTruncation": ".chat_request_prompt_truncation",
    "ChatRequestSafetyMode": ".chat_request_safety_mode",
    "ChatSearchQueriesGenerationEvent": ".chat_search_queries_generation_event",
    "ChatSearchQuery": ".chat_search_query",
    "ChatSearchResult": ".chat_search_result",
    "ChatSearchResultConnector": ".chat_search_result_connector",
    "ChatSearchResultsEvent": ".chat_search_results_event",
    "ChatStreamEndEvent": ".chat_stream_end_event",
    "ChatStreamEndEventFinishReason": ".chat_stream_end_event_finish_reason",
    "ChatStreamEvent": ".chat_stream_event",
    "ChatStreamEventType": ".chat_stream_event_type",
    "ChatStreamRequestCitationQuality": ".chat_stream_request_citation_quality",
    "ChatStreamRequestPromptTruncation": ".chat_stream_request_prompt_truncation",
    "ChatStreamRequestSafetyMode": ".chat_stream_request_safety_mode",
    "ChatStreamStartEvent": ".chat_stream_start_event",
    "ChatTextContent": ".chat_text_content",
    "ChatTextGenerationEvent": ".chat_text_generation_event",
    "ChatTextResponseFormat": ".chat_text_response_format",
    "ChatTextResponseFormatV2": ".chat_text_response_format_v2",
    "ChatThinkingContent": ".chat_thinking_content",
    "ChatToolCallDeltaEvent": ".chat_tool_call_delta_event",
    "ChatToolCallDeltaEventDelta": ".chat_tool_call_delta_event_delta",
    "ChatToolCallDeltaEventDeltaMessage": ".chat_tool_call_delta_event_delta_message",
    "ChatToolCallDeltaEventDeltaMessageToolCalls": ".chat_tool_call_delta_event_delta_message_tool_calls",
    "ChatToolCallDeltaEventDeltaMessageToolCallsFunction": ".chat_tool_call_delta_event_delta_message_tool_calls_function",
    "ChatToolCallEndEvent": ".chat_tool_call_end_event",
    "ChatToolCallStartEvent": ".chat_tool_call_start_event",
    "ChatToolCallStartEventDelta": ".chat_tool_call_start_event_delta",
    "ChatToolCallStartEventDeltaMessage": ".chat_tool_call_start_event_delta_message",
    "ChatToolCallsChunkEvent": ".chat_tool_calls_chunk_event",
    "ChatToolCallsGenerationEvent": ".chat_tool_calls_generation_event",
    "ChatToolMessage": ".chat_tool_message",
    "ChatToolPlanDeltaEvent": ".chat_tool_plan_delta_event",
    "ChatToolPlanDeltaEventDelta": ".chat_tool_plan_delta_event_delta",
    "ChatToolPlanDeltaEventDeltaMessage": ".chat_tool_plan_delta_event_delta_message",
    "ChatToolSource": ".chat_tool_source",
    "ChatbotMessage": ".message",
    "CheckApiKeyResponse": ".check_api_key_response",
    "Citation": ".citation",
    "CitationEndEvent": ".citation_end_event",
    "CitationGenerationStreamedChatResponse": ".streamed_chat_response",
    "CitationOptions": ".citation_options",
    "CitationOptionsMode": ".citation_options_mode",
    "CitationStartEvent": ".citation_start_event",
    "CitationStartEventDelta": ".citation_start_event_delta",
    "CitationStartEventDeltaMessage": ".citation_start_event_delta_message",
    "CitationType": ".citation_type",
    "ClassifyDataMetrics": ".classify_data_metrics",
    "ClassifyExample": ".classify_example",
    "ClassifyRequestTruncate": ".classify_request_truncate",
    "ClassifyResponse": ".classify_response",
    "ClassifyResponseClassificationsItem": ".classify_response_classifications_item",
    "ClassifyResponseClassificationsItemClassificationType": ".classify_response_classifications_item_classification_type",
    "ClassifyResponseClassificationsItemLabelsValue": ".classify_response_classifications_item_labels_value",
    "CompatibleEndpoint": ".compatible_endpoint",
    "Connector": ".connector",
    "ConnectorAuthStatus": ".connector_auth_status",
    "ConnectorOAuth": ".connector_o_auth",
    "Content": ".content",
    "CreateConnectorOAuth": ".create_connector_o_auth",
    "CreateConnectorResponse": ".create_connector_response",
    "CreateConnectorServiceAuth": ".create_connector_service_auth",
    "CreateEmbedJobResponse": ".create_embed_job_response",
    "Dataset": ".dataset",
    "DatasetPart": ".dataset_part",
    "DatasetType": ".dataset_type",
    "DatasetValidationStatus": ".dataset_validation_status",
    "DebugStreamedChatResponse": ".streamed_chat_response",
    "DeleteConnectorResponse": ".delete_connector_response",
    "DetokenizeResponse": ".detokenize_response",
    "Document": ".document",
    "DocumentContent": ".document_content",
    "DocumentSource": ".source",
    "DocumentToolContent": ".tool_content",
    "EmbedByTypeResponse": ".embed_by_type_response",
    "EmbedByTypeResponseEmbeddings": ".embed_by_type_response_embeddings",
    "EmbedByTypeResponseResponseType": ".embed_by_type_response_response_type",
    "EmbedContent": ".embed_content",
    "EmbedFloatsResponse": ".embed_floats_response",
    "EmbedImage": ".embed_image",
    "EmbedImageUrl": ".embed_image_url",
    "EmbedInput": ".embed_input",
    "EmbedInputType": ".embed_input_type",
    "EmbedJob": ".embed_job",
    "EmbedJobStatus": ".embed_job_status",
    "EmbedJobTruncate": ".embed_job_truncate",
    "EmbedRequestTruncate": ".embed_request_truncate",
    "EmbedResponse": ".embed_response",
    "EmbedText": ".embed_text",
    "EmbeddingType": ".embedding_type",
    "EmbeddingsByTypeEmbedResponse": ".embed_response",
    "EmbeddingsFloatsEmbedResponse": ".embed_response",
    "FinetuneDatasetMetrics": ".finetune_dataset_metrics",
    "FinishReason": ".finish_reason",
    "GenerateRequestReturnLikelihoods": ".generate_request_return_likelihoods",
    "GenerateRequestTruncate": ".generate_request_truncate",
    "GenerateStreamEnd": ".generate_stream_end",
    "GenerateStreamEndResponse": ".generate_stream_end_response",
    "GenerateStreamError": ".generate_stream_error",
    "GenerateStreamEvent": ".generate_stream_event",
    "GenerateStreamRequestReturnLikelihoods": ".generate_stream_request_return_likelihoods",
    "GenerateStreamRequestTruncate": ".generate_stream_request_truncate",
    "GenerateStreamText": ".generate_stream_text",
    "GenerateStreamedResponse": ".generate_streamed_response",
    "Generation": ".generation",
    "GetConnectorResponse": ".get_connector_response",
    "GetModelResponse": ".get_model_response",
    "GetModelResponseSamplingDefaults": ".get_model_response_sampling_defaults",
    "Image": ".image",
    "ImageContent": ".image_content",
    "ImageUrl": ".image_url",
    "ImageUrlContent": ".content",
    "ImageUrlDetail": ".image_url_detail",
    "ImageUrlEmbedContent": ".embed_content",
    "JsonObjectResponseFormat": ".response_format",
    "JsonObjectResponseFormatV2": ".response_format_v2",
    "JsonResponseFormat": ".json_response_format",
    "JsonResponseFormatV2": ".json_response_format_v2",
    "LabelMetric": ".label_metric",
    "ListConnectorsResponse": ".list_connectors_response",
    "ListEmbedJobResponse": ".list_embed_job_response",
    "ListModelsResponse": ".list_models_response",
    "LogprobItem": ".logprob_item",
    "Message": ".message",
    "Metrics": ".metrics",
    "NonStreamedChatResponse": ".non_streamed_chat_response",
    "OAuthAuthorizeResponse": ".o_auth_authorize_response",
    "ParseInfo": ".parse_info",
    "RerankDocument": ".rerank_document",
    "RerankRequestDocumentsItem": ".rerank_request_documents_item",
    "RerankResponse": ".rerank_response",
    "RerankResponseResultsItem": ".rerank_response_results_item",
    "RerankResponseResultsItemDocument": ".rerank_response_results_item_document",
    "RerankerDataMetrics": ".reranker_data_metrics",
    "ResponseFormat": ".response_format",
    "ResponseFormatV2": ".response_format_v2",
    "SearchQueriesGenerationStreamedChatResponse": ".streamed_chat_response",
    "SearchResultsStreamedChatResponse": ".streamed_chat_response",
    "SingleGeneration": ".single_generation",
    "SingleGenerationInStream": ".single_generation_in_stream",
    "SingleGenerationTokenLikelihoodsItem": ".single_generation_token_likelihoods_item",
    "Source": ".source",
    "StreamEndGenerateStreamedResponse": ".generate_streamed_response",
    "StreamEndStreamedChatResponse": ".streamed_chat_response",
    "StreamErrorGenerateStreamedResponse": ".generate_streamed_response",
    "StreamStartStreamedChatResponse": ".streamed_chat_response",
    "StreamedChatResponse": ".streamed_chat_response",
    "SummarizeRequestExtractiveness": ".summarize_request_extractiveness",
    "SummarizeRequestFormat": ".summarize_request_format",
    "SummarizeRequestLength": ".summarize_request_length",
    "SummarizeResponse": ".summarize_response",
    "SystemChatMessageV2": ".chat_message_v2",
    "SystemMessage": ".message",
    "SystemMessageV2": ".system_message_v2",
    "SystemMessageV2Content": ".system_message_v2content",
    "SystemMessageV2ContentOneItem": ".system_message_v2content_one_item",
    "TextAssistantMessageResponseContentItem": ".assistant_message_response_content_item",
    "TextAssistantMessageV2ContentOneItem": ".assistant_message_v2content_one_item",
    "TextContent": ".content",
    "TextEmbedContent": ".embed_content",
    "TextGenerationGenerateStreamedResponse": ".generate_streamed_response",
    "TextGenerationStreamedChatResponse": ".streamed_chat_response",
    "TextResponseFormat": ".response_format",
    "TextResponseFormatV2": ".response_format_v2",
    "TextSystemMessageV2ContentOneItem": ".system_message_v2content_one_item",
    "TextToolContent": ".tool_content",
    "Thinking": ".thinking",
    "ThinkingAssistantMessageResponseContentItem": ".assistant_message_response_content_item",
    "ThinkingAssistantMessageV2ContentOneItem": ".assistant_message_v2content_one_item",
    "ThinkingType": ".thinking_type",
    "TokenizeResponse": ".tokenize_response",
    "Tool": ".tool",
    "ToolCall": ".tool_call",
    "ToolCallDelta": ".tool_call_delta",
    "ToolCallV2": ".tool_call_v2",
    "ToolCallV2Function": ".tool_call_v2function",
    "ToolCallsChunkStreamedChatResponse": ".streamed_chat_response",
    "ToolCallsGenerationStreamedChatResponse": ".streamed_chat_response",
    "ToolChatMessageV2": ".chat_message_v2",
    "ToolContent": ".tool_content",
    "ToolMessage": ".message",
    "ToolMessageV2": ".tool_message_v2",
    "ToolMessageV2Content": ".tool_message_v2content",
    "ToolParameterDefinitionsValue": ".tool_parameter_definitions_value",
    "ToolResult": ".tool_result",
    "ToolSource": ".source",
    "ToolV2": ".tool_v2",
    "ToolV2Function": ".tool_v2function",
    "UpdateConnectorResponse": ".update_connector_response",
    "Usage": ".usage",
    "UsageBilledUnits": ".usage_billed_units",
    "UsageTokens": ".usage_tokens",
    "UserChatMessageV2": ".chat_message_v2",
    "UserMessage": ".message",
    "UserMessageV2": ".user_message_v2",
    "UserMessageV2Content": ".user_message_v2content",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve a public attribute on first access (PEP 562 lazy import).

    Looks ``attr_name`` up in ``_dynamic_imports`` and imports the mapped
    submodule on demand, returning either the submodule itself (when the
    attribute name matches the module name) or the attribute defined in it.
    Raises ``AttributeError`` for names outside the table.
    """
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        resolved = import_module(module_name, __package__)
        # When the mapping points at a module named exactly like the attribute,
        # the attribute *is* that submodule; otherwise pull the name out of it.
        return resolved if module_name == f".{attr_name}" else getattr(resolved, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily importable attribute names to dir() and completion."""
    # Iterating the dict yields its keys; no intermediate list is needed.
    return sorted(_dynamic_imports)
# Public API surface of this module; mirrors the keys of ``_dynamic_imports``
# so that star imports and introspection agree with the lazy-import table.
__all__ = [
    "ApiMeta",
    "ApiMetaApiVersion",
    "ApiMetaBilledUnits",
    "ApiMetaTokens",
    "AssistantChatMessageV2",
    "AssistantMessage",
    "AssistantMessageResponse",
    "AssistantMessageResponseContentItem",
    "AssistantMessageV2Content",
    "AssistantMessageV2ContentOneItem",
    "AuthTokenType",
    "ChatCitation",
    "ChatCitationGenerationEvent",
    "ChatCitationType",
    "ChatConnector",
    "ChatContentDeltaEvent",
    "ChatContentDeltaEventDelta",
    "ChatContentDeltaEventDeltaMessage",
    "ChatContentDeltaEventDeltaMessageContent",
    "ChatContentEndEvent",
    "ChatContentStartEvent",
    "ChatContentStartEventDelta",
    "ChatContentStartEventDeltaMessage",
    "ChatContentStartEventDeltaMessageContent",
    "ChatContentStartEventDeltaMessageContentType",
    "ChatDataMetrics",
    "ChatDebugEvent",
    "ChatDocument",
    "ChatDocumentSource",
    "ChatFinishReason",
    "ChatMessage",
    "ChatMessageEndEvent",
    "ChatMessageEndEventDelta",
    "ChatMessageStartEvent",
    "ChatMessageStartEventDelta",
    "ChatMessageStartEventDeltaMessage",
    "ChatMessageV2",
    "ChatMessages",
    "ChatRequestCitationQuality",
    "ChatRequestPromptTruncation",
    "ChatRequestSafetyMode",
    "ChatSearchQueriesGenerationEvent",
    "ChatSearchQuery",
    "ChatSearchResult",
    "ChatSearchResultConnector",
    "ChatSearchResultsEvent",
    "ChatStreamEndEvent",
    "ChatStreamEndEventFinishReason",
    "ChatStreamEvent",
    "ChatStreamEventType",
    "ChatStreamRequestCitationQuality",
    "ChatStreamRequestPromptTruncation",
    "ChatStreamRequestSafetyMode",
    "ChatStreamStartEvent",
    "ChatTextContent",
    "ChatTextGenerationEvent",
    "ChatTextResponseFormat",
    "ChatTextResponseFormatV2",
    "ChatThinkingContent",
    "ChatToolCallDeltaEvent",
    "ChatToolCallDeltaEventDelta",
    "ChatToolCallDeltaEventDeltaMessage",
    "ChatToolCallDeltaEventDeltaMessageToolCalls",
    "ChatToolCallDeltaEventDeltaMessageToolCallsFunction",
    "ChatToolCallEndEvent",
    "ChatToolCallStartEvent",
    "ChatToolCallStartEventDelta",
    "ChatToolCallStartEventDeltaMessage",
    "ChatToolCallsChunkEvent",
    "ChatToolCallsGenerationEvent",
    "ChatToolMessage",
    "ChatToolPlanDeltaEvent",
    "ChatToolPlanDeltaEventDelta",
    "ChatToolPlanDeltaEventDeltaMessage",
    "ChatToolSource",
    "ChatbotMessage",
    "CheckApiKeyResponse",
    "Citation",
    "CitationEndEvent",
    "CitationGenerationStreamedChatResponse",
    "CitationOptions",
    "CitationOptionsMode",
    "CitationStartEvent",
    "CitationStartEventDelta",
    "CitationStartEventDeltaMessage",
    "CitationType",
    "ClassifyDataMetrics",
    "ClassifyExample",
    "ClassifyRequestTruncate",
    "ClassifyResponse",
    "ClassifyResponseClassificationsItem",
    "ClassifyResponseClassificationsItemClassificationType",
    "ClassifyResponseClassificationsItemLabelsValue",
    "CompatibleEndpoint",
    "Connector",
    "ConnectorAuthStatus",
    "ConnectorOAuth",
    "Content",
    "CreateConnectorOAuth",
    "CreateConnectorResponse",
    "CreateConnectorServiceAuth",
    "CreateEmbedJobResponse",
    "Dataset",
    "DatasetPart",
    "DatasetType",
    "DatasetValidationStatus",
    "DebugStreamedChatResponse",
    "DeleteConnectorResponse",
    "DetokenizeResponse",
    "Document",
    "DocumentContent",
    "DocumentSource",
    "DocumentToolContent",
    "EmbedByTypeResponse",
    "EmbedByTypeResponseEmbeddings",
    "EmbedByTypeResponseResponseType",
    "EmbedContent",
    "EmbedFloatsResponse",
    "EmbedImage",
    "EmbedImageUrl",
    "EmbedInput",
    "EmbedInputType",
    "EmbedJob",
    "EmbedJobStatus",
    "EmbedJobTruncate",
    "EmbedRequestTruncate",
    "EmbedResponse",
    "EmbedText",
    "EmbeddingType",
    "EmbeddingsByTypeEmbedResponse",
    "EmbeddingsFloatsEmbedResponse",
    "FinetuneDatasetMetrics",
    "FinishReason",
    "GenerateRequestReturnLikelihoods",
    "GenerateRequestTruncate",
    "GenerateStreamEnd",
    "GenerateStreamEndResponse",
    "GenerateStreamError",
    "GenerateStreamEvent",
    "GenerateStreamRequestReturnLikelihoods",
    "GenerateStreamRequestTruncate",
    "GenerateStreamText",
    "GenerateStreamedResponse",
    "Generation",
    "GetConnectorResponse",
    "GetModelResponse",
    "GetModelResponseSamplingDefaults",
    "Image",
    "ImageContent",
    "ImageUrl",
    "ImageUrlContent",
    "ImageUrlDetail",
    "ImageUrlEmbedContent",
    "JsonObjectResponseFormat",
    "JsonObjectResponseFormatV2",
    "JsonResponseFormat",
    "JsonResponseFormatV2",
    "LabelMetric",
    "ListConnectorsResponse",
    "ListEmbedJobResponse",
    "ListModelsResponse",
    "LogprobItem",
    "Message",
    "Metrics",
    "NonStreamedChatResponse",
    "OAuthAuthorizeResponse",
    "ParseInfo",
    "RerankDocument",
    "RerankRequestDocumentsItem",
    "RerankResponse",
    "RerankResponseResultsItem",
    "RerankResponseResultsItemDocument",
    "RerankerDataMetrics",
    "ResponseFormat",
    "ResponseFormatV2",
    "SearchQueriesGenerationStreamedChatResponse",
    "SearchResultsStreamedChatResponse",
    "SingleGeneration",
    "SingleGenerationInStream",
    "SingleGenerationTokenLikelihoodsItem",
    "Source",
    "StreamEndGenerateStreamedResponse",
    "StreamEndStreamedChatResponse",
    "StreamErrorGenerateStreamedResponse",
    "StreamStartStreamedChatResponse",
    "StreamedChatResponse",
    "SummarizeRequestExtractiveness",
    "SummarizeRequestFormat",
    "SummarizeRequestLength",
    "SummarizeResponse",
    "SystemChatMessageV2",
    "SystemMessage",
    "SystemMessageV2",
    "SystemMessageV2Content",
    "SystemMessageV2ContentOneItem",
    "TextAssistantMessageResponseContentItem",
    "TextAssistantMessageV2ContentOneItem",
    "TextContent",
    "TextEmbedContent",
    "TextGenerationGenerateStreamedResponse",
    "TextGenerationStreamedChatResponse",
    "TextResponseFormat",
    "TextResponseFormatV2",
    "TextSystemMessageV2ContentOneItem",
    "TextToolContent",
    "Thinking",
    "ThinkingAssistantMessageResponseContentItem",
    "ThinkingAssistantMessageV2ContentOneItem",
    "ThinkingType",
    "TokenizeResponse",
    "Tool",
    "ToolCall",
    "ToolCallDelta",
    "ToolCallV2",
    "ToolCallV2Function",
    "ToolCallsChunkStreamedChatResponse",
    "ToolCallsGenerationStreamedChatResponse",
    "ToolChatMessageV2",
    "ToolContent",
    "ToolMessage",
    "ToolMessageV2",
    "ToolMessageV2Content",
    "ToolParameterDefinitionsValue",
    "ToolResult",
    "ToolSource",
    "ToolV2",
    "ToolV2Function",
    "UpdateConnectorResponse",
    "Usage",
    "UsageBilledUnits",
    "UsageTokens",
    "UserChatMessageV2",
    "UserMessage",
    "UserMessageV2",
    "UserMessageV2Content",
]
================================================
FILE: src/cohere/types/api_meta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta_api_version import ApiMetaApiVersion
from .api_meta_billed_units import ApiMetaBilledUnits
from .api_meta_tokens import ApiMetaTokens
class ApiMeta(UncheckedBaseModel):
    """Metadata returned alongside an API response: version, billing, and token usage."""

    api_version: typing.Optional[ApiMetaApiVersion] = None
    billed_units: typing.Optional[ApiMetaBilledUnits] = None
    tokens: typing.Optional[ApiMetaTokens] = None
    cached_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of prompt tokens that hit the inference cache.
    """
    warnings: typing.Optional[typing.List[str]] = None  # warning strings attached to the response, if any
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/api_meta_api_version.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ApiMetaApiVersion(UncheckedBaseModel):
    """The API version that served the request, with deprecation/experimental flags."""

    version: str  # the only required field
    is_deprecated: typing.Optional[bool] = None
    is_experimental: typing.Optional[bool] = None
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/api_meta_billed_units.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ApiMetaBilledUnits(UncheckedBaseModel):
    """Per-request billed-unit counts (images, tokens, search units, classifications)."""

    images: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed images.
    """
    input_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed input tokens.
    """
    image_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed image tokens.
    """
    output_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed output tokens.
    """
    search_units: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed search units.
    """
    classifications: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed classifications units.
    """
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/api_meta_tokens.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ApiMetaTokens(UncheckedBaseModel):
    """Input/output token counts for the request."""

    input_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of tokens used as input to the model.
    """
    output_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of tokens produced by the model.
    """
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/assistant_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .assistant_message_v2content import AssistantMessageV2Content
from .citation import Citation
from .tool_call_v2 import ToolCallV2
class AssistantMessage(UncheckedBaseModel):
    """
    A message from the assistant role can contain text and tool call information.
    """

    tool_calls: typing.Optional[typing.List[ToolCallV2]] = None
    tool_plan: typing.Optional[str] = pydantic.Field(default=None)
    """
    A chain-of-thought style reflection and plan that the model generates when working with Tools.
    """
    content: typing.Optional[AssistantMessageV2Content] = None  # string or list of content items
    citations: typing.Optional[typing.List[Citation]] = None
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/assistant_message_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .assistant_message_response_content_item import AssistantMessageResponseContentItem
from .citation import Citation
from .tool_call_v2 import ToolCallV2
class AssistantMessageResponse(UncheckedBaseModel):
    """
    A message from the assistant role can contain text and tool call information.
    """

    role: typing.Literal["assistant"] = "assistant"  # fixed discriminant for this message type
    tool_calls: typing.Optional[typing.List[ToolCallV2]] = None
    tool_plan: typing.Optional[str] = pydantic.Field(default=None)
    """
    A chain-of-thought style reflection and plan that the model generates when working with Tools.
    """
    content: typing.Optional[typing.List[AssistantMessageResponseContentItem]] = None
    citations: typing.Optional[typing.List[Citation]] = None
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/assistant_message_response_content_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
class TextAssistantMessageResponseContentItem(UncheckedBaseModel):
    """The "text" variant of an assistant-response content item."""

    type: typing.Literal["text"] = "text"  # union discriminant
    text: str
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
class ThinkingAssistantMessageResponseContentItem(UncheckedBaseModel):
    """The "thinking" variant of an assistant-response content item."""

    type: typing.Literal["thinking"] = "thinking"  # union discriminant
    thinking: str
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
# Union discriminated on the "type" field: "text" or "thinking" content items.
AssistantMessageResponseContentItem = typing_extensions.Annotated[
    typing.Union[TextAssistantMessageResponseContentItem, ThinkingAssistantMessageResponseContentItem],
    UnionMetadata(discriminant="type"),
]
================================================
FILE: src/cohere/types/assistant_message_v2content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from .assistant_message_v2content_one_item import AssistantMessageV2ContentOneItem
# Assistant message content: either a plain string or a list of typed (text/thinking) items.
AssistantMessageV2Content = typing.Union[str, typing.List[AssistantMessageV2ContentOneItem]]
================================================
FILE: src/cohere/types/assistant_message_v2content_one_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
class TextAssistantMessageV2ContentOneItem(UncheckedBaseModel):
    """The "text" variant of an assistant (v2) message content item."""

    type: typing.Literal["text"] = "text"  # union discriminant
    text: str
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
class ThinkingAssistantMessageV2ContentOneItem(UncheckedBaseModel):
    """The "thinking" variant of an assistant (v2) message content item."""

    type: typing.Literal["thinking"] = "thinking"  # union discriminant
    thinking: str
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
# Union discriminated on the "type" field: "text" or "thinking" content items.
AssistantMessageV2ContentOneItem = typing_extensions.Annotated[
    typing.Union[TextAssistantMessageV2ContentOneItem, ThinkingAssistantMessageV2ContentOneItem],
    UnionMetadata(discriminant="type"),
]
================================================
FILE: src/cohere/types/auth_token_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Open enum: "bearer" | "basic" | "noscheme"; the typing.Any arm admits values
# outside the listed literals without a validation error.
AuthTokenType = typing.Union[typing.Literal["bearer", "basic", "noscheme"], typing.Any]
================================================
FILE: src/cohere/types/chat_citation.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_citation_type import ChatCitationType
class ChatCitation(UncheckedBaseModel):
    """
    A section of the generated reply which cites external knowledge.
    """

    start: int = pydantic.Field()
    """
    The index of text that the citation starts at, counting from zero. For example, a generation of `Hello, world!` with a citation on `world` would have a start value of `7`. This is because the citation starts at `w`, which is the seventh character.
    """
    end: int = pydantic.Field()
    """
    The index of text that the citation ends after, counting from zero. For example, a generation of `Hello, world!` with a citation on `world` would have an end value of `11`. This is because the citation ends after `d`, which is the eleventh character.
    """
    text: str = pydantic.Field()
    """
    The text of the citation. For example, a generation of `Hello, world!` with a citation of `world` would have a text value of `world`.
    """
    document_ids: typing.List[str] = pydantic.Field()
    """
    Identifiers of documents cited by this section of the generated reply.
    """
    type: typing.Optional[ChatCitationType] = pydantic.Field(default=None)
    """
    The type of citation which indicates what part of the response the citation is for.
    """
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_citation_generation_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_citation import ChatCitation
from .chat_stream_event import ChatStreamEvent
class ChatCitationGenerationEvent(ChatStreamEvent):
    """Stream event carrying the citations produced for the generated reply."""

    citations: typing.List[ChatCitation] = pydantic.Field()
    """
    Citations for the generated reply.
    """
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_citation_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Open enum: "TEXT_CONTENT" | "PLAN"; the typing.Any arm admits values outside
# the listed literals without a validation error.
ChatCitationType = typing.Union[typing.Literal["TEXT_CONTENT", "PLAN"], typing.Any]
================================================
FILE: src/cohere/types/chat_connector.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatConnector(UncheckedBaseModel):
    """
    The connector used for fetching documents.
    """

    id: str = pydantic.Field()
    """
    The identifier of the connector.
    """
    user_access_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    When specified, this user access token will be passed to the connector in the Authorization header instead of the Cohere generated one.
    """
    continue_on_failure: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Defaults to `false`.
    When `true`, the request will continue if this connector returned an error.
    """
    options: typing.Optional[typing.Dict[str, typing.Any]] = pydantic.Field(default=None)
    """
    Provides the connector with different settings at request time. The key/value pairs of this object are specific to each connector.
    For example, the connector `web-search` supports the `site` option, which limits search results to the specified domain.
    """
    # Unknown response fields are retained (extra="allow") on both pydantic majors.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_delta_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_content_delta_event_delta import ChatContentDeltaEventDelta
from .chat_stream_event_type import ChatStreamEventType
from .logprob_item import LogprobItem
class ChatContentDeltaEvent(ChatStreamEventType):
    """
    A streamed delta event which contains a delta of chat text content.
    """

    # Index of the content block this delta applies to.
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatContentDeltaEventDelta] = None
    # NOTE(review): a single LogprobItem per delta event — presumably the
    # log-probabilities for the tokens in this delta; confirm against the API spec.
    logprobs: typing.Optional[LogprobItem] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_delta_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_content_delta_event_delta_message import ChatContentDeltaEventDeltaMessage
class ChatContentDeltaEventDelta(UncheckedBaseModel):
    """Payload wrapper for a content-delta event; holds the partial message."""

    message: typing.Optional[ChatContentDeltaEventDeltaMessage] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_delta_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_content_delta_event_delta_message_content import ChatContentDeltaEventDeltaMessageContent
class ChatContentDeltaEventDeltaMessage(UncheckedBaseModel):
    """Partial message inside a content-delta event; holds the content fragment."""

    content: typing.Optional[ChatContentDeltaEventDeltaMessageContent] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_delta_event_delta_message_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatContentDeltaEventDeltaMessageContent(UncheckedBaseModel):
    """Text fragment of a content delta: either visible text or thinking text."""

    # Incremental "thinking" (internal reasoning) text, when present.
    thinking: typing.Optional[str] = None
    # Incremental visible reply text, when present.
    text: typing.Optional[str] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
class ChatContentEndEvent(ChatStreamEventType):
    """
    A streamed delta event which signifies that the content block has ended.
    """

    # Index of the content block that has ended.
    index: typing.Optional[int] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_content_start_event_delta import ChatContentStartEventDelta
from .chat_stream_event_type import ChatStreamEventType
class ChatContentStartEvent(ChatStreamEventType):
    """
    A streamed delta event which signifies that a new content block has started.
    """

    # Index of the newly started content block.
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatContentStartEventDelta] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_start_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_content_start_event_delta_message import ChatContentStartEventDeltaMessage
class ChatContentStartEventDelta(UncheckedBaseModel):
    """Payload wrapper for a content-start event; holds the opening message."""

    message: typing.Optional[ChatContentStartEventDeltaMessage] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_start_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_content_start_event_delta_message_content import ChatContentStartEventDeltaMessageContent
class ChatContentStartEventDeltaMessage(UncheckedBaseModel):
    """Opening message inside a content-start event; holds the initial content."""

    content: typing.Optional[ChatContentStartEventDeltaMessageContent] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_start_event_delta_message_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_content_start_event_delta_message_content_type import ChatContentStartEventDeltaMessageContentType
class ChatContentStartEventDeltaMessageContent(UncheckedBaseModel):
    """Initial content of a starting block; `type` says whether it is text or thinking."""

    thinking: typing.Optional[str] = None
    text: typing.Optional[str] = None
    # Discriminates which of the two string fields above is populated ("text" | "thinking").
    type: typing.Optional[ChatContentStartEventDeltaMessageContentType] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_content_start_event_delta_message_content_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
ChatContentStartEventDeltaMessageContentType = typing.Union[typing.Literal["text", "thinking"], typing.Any]
================================================
FILE: src/cohere/types/chat_data_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatDataMetrics(UncheckedBaseModel):
    """Aggregate statistics over a chat fine-tuning dataset (train/eval turn counts, preamble)."""

    num_train_turns: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all turns of valid train examples.
    """

    num_eval_turns: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all turns of valid eval examples.
    """

    preamble: typing.Optional[str] = pydantic.Field(default=None)
    """
    The preamble of this dataset.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_debug_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent
class ChatDebugEvent(ChatStreamEvent):
    """Streamed debug event exposing the raw prompt sent to the model."""

    prompt: typing.Optional[str] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_document.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Free-form string-to-string mapping; both keys and values are passed to the model.
ChatDocument = typing.Dict[str, str]
"""
Relevant information that could be used by the model to generate a more accurate reply.
The contents of each document are generally short (under 300 words), and are passed in the form of a
dictionary of strings. Some suggested keys are "text", "author", "date". Both the key name and the value will be
passed to the model.
"""
================================================
FILE: src/cohere/types/chat_document_source.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatDocumentSource(UncheckedBaseModel):
    """
    A document source object containing the unique identifier of the document and the document itself.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The unique identifier of the document
    """

    # Arbitrary document payload; key/value shape is connector-specific.
    document: typing.Optional[typing.Dict[str, typing.Any]] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_finish_reason.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Why generation stopped; `typing.Any` keeps unknown future values deserializable.
ChatFinishReason = typing.Union[
    typing.Literal["COMPLETE", "STOP_SEQUENCE", "MAX_TOKENS", "TOOL_CALL", "ERROR", "TIMEOUT"], typing.Any
]
================================================
FILE: src/cohere/types/chat_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_call import ToolCall
class ChatMessage(UncheckedBaseModel):
    """
    Represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content.
    The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used.
    """

    # NOTE(review): the docstring mentions `role` but no such field is declared here —
    # presumably it is contributed by a discriminated-union wrapper type; confirm.
    message: str = pydantic.Field()
    """
    Contents of the chat message.
    """

    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_message_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_message_end_event_delta import ChatMessageEndEventDelta
from .chat_stream_event_type import ChatStreamEventType
class ChatMessageEndEvent(ChatStreamEventType):
    """
    A streamed event which signifies that the chat message has ended.
    """

    id: typing.Optional[str] = None
    # Final delta carrying finish reason, usage, and any error message.
    delta: typing.Optional[ChatMessageEndEventDelta] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_message_end_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_finish_reason import ChatFinishReason
from .usage import Usage
class ChatMessageEndEventDelta(UncheckedBaseModel):
    """Terminal payload of a message-end event: error (if any), finish reason, and token usage."""

    error: typing.Optional[str] = pydantic.Field(default=None)
    """
    An error message if an error occurred during the generation.
    """

    finish_reason: typing.Optional[ChatFinishReason] = None
    usage: typing.Optional[Usage] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_message_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_message_start_event_delta import ChatMessageStartEventDelta
from .chat_stream_event_type import ChatStreamEventType
class ChatMessageStartEvent(ChatStreamEventType):
    """
    A streamed event which signifies that a stream has started.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Unique identifier for the generated reply.
    """

    delta: typing.Optional[ChatMessageStartEventDelta] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_message_start_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_message_start_event_delta_message import ChatMessageStartEventDeltaMessage
class ChatMessageStartEventDelta(UncheckedBaseModel):
    """Payload wrapper for a message-start event; holds the opening message metadata."""

    message: typing.Optional[ChatMessageStartEventDeltaMessage] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_message_start_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatMessageStartEventDeltaMessage(UncheckedBaseModel):
    """Opening message metadata for a message-start event; role is always "assistant"."""

    role: typing.Optional[typing.Literal["assistant"]] = pydantic.Field(default=None)
    """
    The role of the message.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_message_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .assistant_message_v2content import AssistantMessageV2Content
from .citation import Citation
from .system_message_v2content import SystemMessageV2Content
from .tool_call_v2 import ToolCallV2
from .tool_message_v2content import ToolMessageV2Content
from .user_message_v2content import UserMessageV2Content
class UserChatMessageV2(UncheckedBaseModel):
    """
    Represents a single message in the chat history from a given role.
    """

    # Discriminant for the ChatMessageV2 union; always "user" for this variant.
    role: typing.Literal["user"] = "user"
    content: UserMessageV2Content

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
class AssistantChatMessageV2(UncheckedBaseModel):
    """
    Represents a single message in the chat history from a given role.
    """

    # Discriminant for the ChatMessageV2 union; always "assistant" for this variant.
    role: typing.Literal["assistant"] = "assistant"
    tool_calls: typing.Optional[typing.List[ToolCallV2]] = None
    tool_plan: typing.Optional[str] = None
    content: typing.Optional[AssistantMessageV2Content] = None
    citations: typing.Optional[typing.List[Citation]] = None

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
class SystemChatMessageV2(UncheckedBaseModel):
    """
    Represents a single message in the chat history from a given role.
    """

    # Discriminant for the ChatMessageV2 union; always "system" for this variant.
    role: typing.Literal["system"] = "system"
    content: SystemMessageV2Content

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
class ToolChatMessageV2(UncheckedBaseModel):
    """
    Represents a single message in the chat history from a given role.
    """

    # Discriminant for the ChatMessageV2 union; always "tool" for this variant.
    role: typing.Literal["tool"] = "tool"
    # Identifier linking this result back to the tool call that produced it.
    tool_call_id: str
    content: ToolMessageV2Content

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
# Discriminated union over the `role` field ("user" | "assistant" | "system" | "tool").
ChatMessageV2 = typing_extensions.Annotated[
    typing.Union[UserChatMessageV2, AssistantChatMessageV2, SystemChatMessageV2, ToolChatMessageV2],
    UnionMetadata(discriminant="role"),
]
================================================
FILE: src/cohere/types/chat_messages.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from .chat_message_v2 import ChatMessageV2
ChatMessages = typing.List[ChatMessageV2]
"""
A list of chat messages in chronological order, representing a conversation between the user and the model.
Messages can be from `User`, `Assistant`, `Tool` and `System` roles. Learn more about messages and roles in [the Chat API guide](https://docs.cohere.com/v2/docs/chat-api).
"""
================================================
FILE: src/cohere/types/chat_request_citation_quality.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
ChatRequestCitationQuality = typing.Union[typing.Literal["ENABLED", "DISABLED", "FAST", "ACCURATE", "OFF"], typing.Any]
================================================
FILE: src/cohere/types/chat_request_prompt_truncation.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
ChatRequestPromptTruncation = typing.Union[typing.Literal["OFF", "AUTO", "AUTO_PRESERVE_ORDER"], typing.Any]
================================================
FILE: src/cohere/types/chat_request_safety_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
ChatRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "NONE"], typing.Any]
================================================
FILE: src/cohere/types/chat_search_queries_generation_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_search_query import ChatSearchQuery
from .chat_stream_event import ChatStreamEvent
class ChatSearchQueriesGenerationEvent(ChatStreamEvent):
    """Streamed event carrying the search queries generated for the RAG flow."""

    search_queries: typing.List[ChatSearchQuery] = pydantic.Field()
    """
    Generated search queries, meant to be used as part of the RAG flow.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_search_query.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatSearchQuery(UncheckedBaseModel):
    """
    The generated search query. Contains the text of the query and a unique identifier for the query.
    """

    text: str = pydantic.Field()
    """
    The text of the search query.
    """

    generation_id: str = pydantic.Field()
    """
    Unique identifier for the generated search query. Useful for submitting feedback.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_search_result.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_search_query import ChatSearchQuery
from .chat_search_result_connector import ChatSearchResultConnector
class ChatSearchResult(UncheckedBaseModel):
    """Result of running one search query through a connector: the documents found and any error."""

    search_query: typing.Optional[ChatSearchQuery] = None
    connector: ChatSearchResultConnector = pydantic.Field()
    """
    The connector from which this result comes from.
    """

    document_ids: typing.List[str] = pydantic.Field()
    """
    Identifiers of documents found by this search query.
    """

    error_message: typing.Optional[str] = pydantic.Field(default=None)
    """
    An error message if the search failed.
    """

    continue_on_failure: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether a chat request should continue or not if the request to this connector fails.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_search_result_connector.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatSearchResultConnector(UncheckedBaseModel):
    """
    The connector used for fetching documents.
    """

    id: str = pydantic.Field()
    """
    The identifier of the connector.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_search_results_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_document import ChatDocument
from .chat_search_result import ChatSearchResult
from .chat_stream_event import ChatStreamEvent
class ChatSearchResultsEvent(ChatStreamEvent):
    """Streamed event carrying the search results and fetched documents for the RAG flow."""

    search_results: typing.Optional[typing.List[ChatSearchResult]] = pydantic.Field(default=None)
    """
    Conducted searches and the ids of documents retrieved from each of them.
    """

    documents: typing.Optional[typing.List[ChatDocument]] = pydantic.Field(default=None)
    """
    Documents fetched from searches or provided by the user.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_stream_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_end_event_finish_reason import ChatStreamEndEventFinishReason
from .chat_stream_event import ChatStreamEvent
from .non_streamed_chat_response import NonStreamedChatResponse
class ChatStreamEndEvent(ChatStreamEvent):
    """Final streamed event: why generation finished plus the consolidated response."""

    finish_reason: ChatStreamEndEventFinishReason = pydantic.Field()
    """
    - `COMPLETE` - the model sent back a finished reply
    - `ERROR_LIMIT` - the reply was cut off because the model reached the maximum number of tokens for its context length
    - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens specified by the max_tokens parameter
    - `ERROR` - something went wrong when generating the reply
    - `ERROR_TOXIC` - the model generated a reply that was deemed toxic
    """

    response: NonStreamedChatResponse = pydantic.Field()
    """
    The consolidated response from the model. Contains the generated reply and all the other information streamed back in the previous events.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_stream_end_event_finish_reason.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Finish reasons for a stream-end event (documented on ChatStreamEndEvent.finish_reason);
# `typing.Any` keeps unknown future values deserializable.
ChatStreamEndEventFinishReason = typing.Union[
    typing.Literal["COMPLETE", "ERROR_LIMIT", "MAX_TOKENS", "ERROR", "ERROR_TOXIC"], typing.Any
]
================================================
FILE: src/cohere/types/chat_stream_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatStreamEvent(UncheckedBaseModel):
    """Base class for v1 chat stream events; declares only the permissive pydantic config."""

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_stream_event_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatStreamEventType(UncheckedBaseModel):
    """
    The streamed event types
    """

    # Base class for v2 chat stream events; declares only the permissive pydantic config.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_stream_request_citation_quality.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Citation-quality options for a streamed chat request; `typing.Any` keeps unknown future values deserializable.
ChatStreamRequestCitationQuality = typing.Union[
    typing.Literal["ENABLED", "DISABLED", "FAST", "ACCURATE", "OFF"], typing.Any
]
================================================
FILE: src/cohere/types/chat_stream_request_prompt_truncation.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
ChatStreamRequestPromptTruncation = typing.Union[typing.Literal["OFF", "AUTO", "AUTO_PRESERVE_ORDER"], typing.Any]
================================================
FILE: src/cohere/types/chat_stream_request_safety_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
ChatStreamRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "NONE"], typing.Any]
================================================
FILE: src/cohere/types/chat_stream_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent
class ChatStreamStartEvent(ChatStreamEvent):
    """First streamed event of a v1 chat stream; carries the generation identifier."""

    generation_id: str = pydantic.Field()
    """
    Unique identifier for the generated reply. Useful for submitting feedback.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_text_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatTextContent(UncheckedBaseModel):
    """
    Text content of the message.
    """

    text: str

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_text_generation_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent
class ChatTextGenerationEvent(ChatStreamEvent):
    """Streamed event carrying the next chunk of generated reply text."""

    text: str = pydantic.Field()
    """
    The next batch of text generated by the model.
    """

    # Permit extra fields from the API for forward compatibility (v1 and v2 spellings).
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_text_response_format.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatTextResponseFormat(UncheckedBaseModel):
    """
    Response-format option that declares no configuration fields of its own
    beyond what the base model provides.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_text_response_format_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatTextResponseFormatV2(UncheckedBaseModel):
    """
    V2 response-format option that declares no configuration fields of its own
    beyond what the base model provides.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_thinking_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatThinkingContent(UncheckedBaseModel):
    """
    Thinking content of the message. This will be present when `thinking` is enabled, and will contain the model's internal reasoning.
    """

    # The reasoning text produced by the model.
    thinking: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_delta_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .chat_tool_call_delta_event_delta import ChatToolCallDeltaEventDelta
class ChatToolCallDeltaEvent(ChatStreamEventType):
    """
    A streamed event delta which signifies a delta in tool call arguments.
    """

    # Index identifying which streamed tool call this delta belongs to.
    index: typing.Optional[int] = None
    # The incremental payload for the tool call.
    delta: typing.Optional[ChatToolCallDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_delta_event_delta_message import ChatToolCallDeltaEventDeltaMessage
class ChatToolCallDeltaEventDelta(UncheckedBaseModel):
    """
    Wrapper for the message portion of a tool-call delta event.
    """

    message: typing.Optional[ChatToolCallDeltaEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_delta_event_delta_message_tool_calls import ChatToolCallDeltaEventDeltaMessageToolCalls
class ChatToolCallDeltaEventDeltaMessage(UncheckedBaseModel):
    """
    Message payload of a tool-call delta event.
    """

    tool_calls: typing.Optional[ChatToolCallDeltaEventDeltaMessageToolCalls] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta_message_tool_calls.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_delta_event_delta_message_tool_calls_function import (
ChatToolCallDeltaEventDeltaMessageToolCallsFunction,
)
class ChatToolCallDeltaEventDeltaMessageToolCalls(UncheckedBaseModel):
    """
    Tool-call portion of a tool-call delta event message.
    """

    function: typing.Optional[ChatToolCallDeltaEventDeltaMessageToolCallsFunction] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_delta_event_delta_message_tool_calls_function.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatToolCallDeltaEventDeltaMessageToolCallsFunction(UncheckedBaseModel):
    """
    Function portion of a streamed tool-call delta.
    """

    # Presumably a chunk of the streamed function-arguments string — confirm against the API reference.
    arguments: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
class ChatToolCallEndEvent(ChatStreamEventType):
    """
    A streamed event delta which signifies a tool call has finished streaming.
    """

    # Index identifying which streamed tool call has ended.
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .chat_tool_call_start_event_delta import ChatToolCallStartEventDelta
class ChatToolCallStartEvent(ChatStreamEventType):
    """
    A streamed event delta which signifies a tool call has started streaming.
    """

    # Index identifying the tool call that is starting.
    index: typing.Optional[int] = None
    # Initial payload of the tool call.
    delta: typing.Optional[ChatToolCallStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_start_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_call_start_event_delta_message import ChatToolCallStartEventDeltaMessage
class ChatToolCallStartEventDelta(UncheckedBaseModel):
    """
    Wrapper for the message portion of a tool-call start event.
    """

    message: typing.Optional[ChatToolCallStartEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_call_start_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_call_v2 import ToolCallV2
class ChatToolCallStartEventDeltaMessage(UncheckedBaseModel):
    """
    Message payload of a tool-call start event.
    """

    # NOTE(review): the field name is plural but the type is a single ToolCallV2 —
    # this mirrors the generated API definition; confirm upstream before changing.
    tool_calls: typing.Optional[ToolCallV2] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_calls_chunk_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent
from .tool_call_delta import ToolCallDelta
class ChatToolCallsChunkEvent(ChatStreamEvent):
    """
    Streamed chat event carrying a tool-call delta and optional accompanying text.
    """

    tool_call_delta: ToolCallDelta
    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_calls_generation_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event import ChatStreamEvent
from .tool_call import ToolCall
class ChatToolCallsGenerationEvent(ChatStreamEvent):
    """
    Streamed chat event emitted when the model has generated tool calls.
    """

    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    The text generated related to the tool calls generated
    """

    # The tool calls produced by the model.
    tool_calls: typing.List[ToolCall]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_result import ToolResult
class ChatToolMessage(UncheckedBaseModel):
    """
    Represents tool result in the chat history.
    """

    # Results returned by the invoked tools, if any.
    tool_results: typing.Optional[typing.List[ToolResult]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_plan_delta_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .chat_tool_plan_delta_event_delta import ChatToolPlanDeltaEventDelta
class ChatToolPlanDeltaEvent(ChatStreamEventType):
    """
    A streamed event which contains a delta of tool plan text.
    """

    # The incremental tool-plan payload.
    delta: typing.Optional[ChatToolPlanDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_plan_delta_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .chat_tool_plan_delta_event_delta_message import ChatToolPlanDeltaEventDeltaMessage
class ChatToolPlanDeltaEventDelta(UncheckedBaseModel):
    """
    Wrapper for the message portion of a tool-plan delta event.
    """

    message: typing.Optional[ChatToolPlanDeltaEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_plan_delta_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatToolPlanDeltaEventDeltaMessage(UncheckedBaseModel):
    """
    Message payload of a tool-plan delta event.
    """

    # The next chunk of tool-plan text.
    tool_plan: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/chat_tool_source.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ChatToolSource(UncheckedBaseModel):
    """
    Source information for a document produced by a tool call.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The unique identifier of the document
    """

    # Presumably the raw output the tool returned for this document — confirm against the API reference.
    tool_output: typing.Optional[typing.Dict[str, typing.Any]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/check_api_key_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class CheckApiKeyResponse(UncheckedBaseModel):
    """
    Response returned when checking an API key.
    """

    # Whether the provided API key is valid.
    valid: bool
    # Identifier of the key's organization, when available.
    organization_id: typing.Optional[str] = None
    # Identifier of the key's owner, when available.
    owner_id: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/citation.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation_type import CitationType
from .source import Source
class Citation(UncheckedBaseModel):
    """
    Citation information containing sources and the text cited.
    """

    start: typing.Optional[int] = pydantic.Field(default=None)
    """
    Start index of the cited snippet in the original source text.
    """

    end: typing.Optional[int] = pydantic.Field(default=None)
    """
    End index of the cited snippet in the original source text.
    """

    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    Text snippet that is being cited.
    """

    # The sources this citation refers to.
    sources: typing.Optional[typing.List[Source]] = None

    content_index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Index of the content block in which this citation appears.
    """

    # The kind of content the citation applies to (see CitationType).
    type: typing.Optional[CitationType] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/citation_end_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
class CitationEndEvent(ChatStreamEventType):
    """
    A streamed event which signifies a citation has finished streaming.
    """

    # Index identifying which streamed citation has ended.
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/citation_options.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation_options_mode import CitationOptionsMode
class CitationOptions(UncheckedBaseModel):
    """
    Options for controlling citation generation.
    """

    mode: typing.Optional[CitationOptionsMode] = pydantic.Field(default=None)
    """
    Defaults to `"enabled"`.
    Citations are enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
    """

    # NOTE(review): the doc above references `"type": "disabled"` while the field is `mode`
    # and CitationOptionsMode uses uppercase values — confirm wording against the API reference.

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/citation_options_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Citation generation mode. The trailing `typing.Any` arm keeps parsing
# forward-compatible with values not yet listed in the literal.
CitationOptionsMode = typing.Union[typing.Literal["ENABLED", "DISABLED", "FAST", "ACCURATE", "OFF"], typing.Any]
================================================
FILE: src/cohere/types/citation_start_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .chat_stream_event_type import ChatStreamEventType
from .citation_start_event_delta import CitationStartEventDelta
class CitationStartEvent(ChatStreamEventType):
    """
    A streamed event which signifies a citation has been created.
    """

    # Index identifying the citation that is starting.
    index: typing.Optional[int] = None
    # Initial payload of the citation.
    delta: typing.Optional[CitationStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/citation_start_event_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation_start_event_delta_message import CitationStartEventDeltaMessage
class CitationStartEventDelta(UncheckedBaseModel):
    """
    Wrapper for the message portion of a citation-start event.
    """

    message: typing.Optional[CitationStartEventDeltaMessage] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/citation_start_event_delta_message.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .citation import Citation
class CitationStartEventDeltaMessage(UncheckedBaseModel):
    """
    Message payload of a citation-start event.
    """

    # NOTE(review): plural field name with a single Citation type — this mirrors the
    # generated API definition; confirm upstream before changing.
    citations: typing.Optional[Citation] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/citation_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# The kind of content a citation applies to; `typing.Any` tolerates unlisted future values.
CitationType = typing.Union[typing.Literal["TEXT_CONTENT", "THINKING_CONTENT", "PLAN"], typing.Any]
================================================
FILE: src/cohere/types/classify_data_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .label_metric import LabelMetric
class ClassifyDataMetrics(UncheckedBaseModel):
    """
    Per-label metrics for a classification dataset.
    """

    label_metrics: typing.Optional[typing.List[LabelMetric]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/classify_example.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ClassifyExample(UncheckedBaseModel):
    """
    A labeled example used for classification.
    """

    # The example text.
    text: typing.Optional[str] = None
    # The label associated with the text.
    label: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/classify_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Truncation strategy for over-long inputs: no truncation, or cut from the start or end.
ClassifyRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]
================================================
FILE: src/cohere/types/classify_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .classify_response_classifications_item import ClassifyResponseClassificationsItem
class ClassifyResponse(UncheckedBaseModel):
    """
    Response from the classify endpoint.
    """

    # Unique identifier of this response.
    id: str
    # One classification result per input.
    classifications: typing.List[ClassifyResponseClassificationsItem]
    # API metadata (e.g. billing/version info), when present.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/classify_response_classifications_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .classify_response_classifications_item_classification_type import (
ClassifyResponseClassificationsItemClassificationType,
)
from .classify_response_classifications_item_labels_value import ClassifyResponseClassificationsItemLabelsValue
class ClassifyResponseClassificationsItem(UncheckedBaseModel):
    """
    A single classification result for one input.
    """

    # Unique identifier of this classification.
    id: str

    input: typing.Optional[str] = pydantic.Field(default=None)
    """
    The input text that was classified
    """

    prediction: typing.Optional[str] = pydantic.Field(default=None)
    """
    The predicted label for the associated query (only filled for single-label models)
    """

    predictions: typing.List[str] = pydantic.Field()
    """
    An array containing the predicted labels for the associated query (only filled for single-label classification)
    """

    confidence: typing.Optional[float] = pydantic.Field(default=None)
    """
    The confidence score for the top predicted class (only filled for single-label classification)
    """

    confidences: typing.List[float] = pydantic.Field()
    """
    An array containing the confidence scores of all the predictions in the same order
    """

    labels: typing.Dict[str, ClassifyResponseClassificationsItemLabelsValue] = pydantic.Field()
    """
    A map containing each label and its confidence score according to the classifier. All the confidence scores add up to 1 for single-label classification. For multi-label classification the label confidences are independent of each other, so they don't have to sum up to 1.
    """

    classification_type: ClassifyResponseClassificationsItemClassificationType = pydantic.Field()
    """
    The type of classification performed
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/classify_response_classifications_item_classification_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Whether the classifier produced single-label or multi-label predictions.
ClassifyResponseClassificationsItemClassificationType = typing.Union[
    typing.Literal["single-label", "multi-label"], typing.Any
]
================================================
FILE: src/cohere/types/classify_response_classifications_item_labels_value.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ClassifyResponseClassificationsItemLabelsValue(UncheckedBaseModel):
    """
    Confidence score attached to a single label in a classification result.
    """

    confidence: typing.Optional[float] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/compatible_endpoint.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Endpoints a model may be compatible with; `typing.Any` tolerates unlisted future values.
CompatibleEndpoint = typing.Union[
    typing.Literal["chat", "embed", "classify", "summarize", "rerank", "rate", "generate"], typing.Any
]
================================================
FILE: src/cohere/types/connector.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector_auth_status import ConnectorAuthStatus
from .connector_o_auth import ConnectorOAuth
class Connector(UncheckedBaseModel):
    """
    A connector allows you to integrate data sources with the '/chat' endpoint to create
    grounded generations with citations to the data source, supplying documents to help
    answer users.
    """

    id: str = pydantic.Field()
    """
    The unique identifier of the connector (used in both `/connectors` & `/chat` endpoints).
    This is automatically created from the name of the connector upon registration.
    """

    organization_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The organization to which this connector belongs. This is automatically set to
    the organization of the user who created the connector.
    """

    name: str = pydantic.Field()
    """
    A human-readable name for the connector.
    """

    description: typing.Optional[str] = pydantic.Field(default=None)
    """
    A description of the connector.
    """

    url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The URL of the connector that will be used to search for documents.
    """

    created_at: dt.datetime = pydantic.Field()
    """
    The UTC time at which the connector was created.
    """

    updated_at: dt.datetime = pydantic.Field()
    """
    The UTC time at which the connector was last updated.
    """

    excludes: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    A list of fields to exclude from the prompt (fields remain in the document).
    """

    auth_type: typing.Optional[str] = pydantic.Field(default=None)
    """
    The type of authentication/authorization used by the connector. Possible values: [oauth, service_auth]
    """

    oauth: typing.Optional[ConnectorOAuth] = pydantic.Field(default=None)
    """
    The OAuth 2.0 configuration for the connector.
    """

    auth_status: typing.Optional[ConnectorAuthStatus] = pydantic.Field(default=None)
    """
    The OAuth status for the user making the request. One of ["valid", "expired", ""]. Empty string (field is omitted) means the user has not authorized the connector yet.
    """

    active: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether the connector is active or not.
    """

    continue_on_failure: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether a chat request should continue or not if the request to this connector fails.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/connector_auth_status.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# OAuth status of a connector for the requesting user; `typing.Any` tolerates other
# values the API may send (e.g. an empty string when never authorized).
ConnectorAuthStatus = typing.Union[typing.Literal["valid", "expired"], typing.Any]
================================================
FILE: src/cohere/types/connector_o_auth.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ConnectorOAuth(UncheckedBaseModel):
    """
    OAuth 2.0 configuration for a connector.
    """

    client_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client ID. This field is encrypted at rest.
    """

    client_secret: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response.
    """

    authorize_url: str = pydantic.Field()
    """
    The OAuth 2.0 /authorize endpoint to use when users authorize the connector.
    """

    token_url: str = pydantic.Field()
    """
    The OAuth 2.0 /token endpoint to use when users authorize the connector.
    """

    scope: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth scopes to request when users authorize the connector.
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/content.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .image_url import ImageUrl
class TextContent(UncheckedBaseModel):
    """
    A Content block which contains information about the content type and the content itself.
    """

    # Union discriminant; fixed to "text" for this variant.
    type: typing.Literal["text"] = "text"
    text: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
class ImageUrlContent(UncheckedBaseModel):
    """
    A Content block which contains information about the content type and the content itself.
    """

    # Discriminant value for the Content union (see the Content alias in this module).
    type: typing.Literal["image_url"] = "image_url"
    image_url: ImageUrl

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
Content = typing_extensions.Annotated[typing.Union[TextContent, ImageUrlContent], UnionMetadata(discriminant="type")]
================================================
FILE: src/cohere/types/create_connector_o_auth.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class CreateConnectorOAuth(UncheckedBaseModel):
    """
    OAuth 2.0 configuration supplied when creating a connector. All fields are
    optional here, unlike ConnectorOAuth where the endpoint URLs are required.
    """

    client_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client ID. This field is encrypted at rest.
    """

    client_secret: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response.
    """

    authorize_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 /authorize endpoint to use when users authorize the connector.
    """

    token_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 /token endpoint to use when users authorize the connector.
    """

    scope: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth scopes to request when users authorize the connector.
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/create_connector_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector
class CreateConnectorResponse(UncheckedBaseModel):
    """
    Response envelope returned when a connector is created.
    """

    # The newly created connector.
    connector: Connector

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/create_connector_service_auth.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .auth_token_type import AuthTokenType
class CreateConnectorServiceAuth(UncheckedBaseModel):
    """
    Service-level authentication settings supplied when creating a connector.
    """

    # Kind of auth token (see AuthTokenType for the accepted values).
    type: AuthTokenType
    token: str = pydantic.Field()
    """
    The token that will be used in the HTTP Authorization header when making requests to the connector. This field is encrypted at rest and never returned in a response.
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/create_embed_job_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
class CreateEmbedJobResponse(UncheckedBaseModel):
    """
    Response from creating an embed job.
    """

    # Identifier of the newly created embed job.
    job_id: str
    # API metadata (billing/version info); absent on some responses.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/dataset.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel
from .dataset_part import DatasetPart
from .dataset_type import DatasetType
from .dataset_validation_status import DatasetValidationStatus
class Dataset(UncheckedBaseModel):
    """
    A dataset stored with the API, including its validation state and the
    underlying file parts that make it up.
    """

    id: str = pydantic.Field()
    """
    The dataset ID
    """

    name: str = pydantic.Field()
    """
    The name of the dataset
    """

    created_at: dt.datetime = pydantic.Field()
    """
    The creation date
    """

    updated_at: dt.datetime = pydantic.Field()
    """
    The last update date
    """

    # What kind of dataset this is (see DatasetType for the accepted values).
    dataset_type: DatasetType
    # Outcome of server-side validation (see DatasetValidationStatus).
    validation_status: DatasetValidationStatus
    validation_error: typing.Optional[str] = pydantic.Field(default=None)
    """
    Errors found during validation
    """

    # Named "schema_" locally because "schema" collides with a BaseModel
    # attribute; serialized over the wire as "schema" via the alias.
    schema_: typing_extensions.Annotated[
        typing.Optional[str],
        FieldMetadata(alias="schema"),
        pydantic.Field(alias="schema", description="the avro schema of the dataset"),
    ] = None
    required_fields: typing.Optional[typing.List[str]] = None
    preserve_fields: typing.Optional[typing.List[str]] = None
    dataset_parts: typing.Optional[typing.List[DatasetPart]] = pydantic.Field(default=None)
    """
    the underlying files that make up the dataset
    """

    validation_warnings: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    warnings found during validation
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/dataset_part.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class DatasetPart(UncheckedBaseModel):
    """
    One underlying file of a dataset (see Dataset.dataset_parts).
    """

    id: str = pydantic.Field()
    """
    The dataset part ID
    """

    name: str = pydantic.Field()
    """
    The name of the dataset part
    """

    url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The download url of the file
    """

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    The index of the file
    """

    size_bytes: typing.Optional[int] = pydantic.Field(default=None)
    """
    The size of the file in bytes
    """

    num_rows: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of rows in the file
    """

    original_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The download url of the original file
    """

    samples: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    The first few rows of the parsed file
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/dataset_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Known dataset types. The trailing typing.Any in the union lets values
# introduced by newer API versions pass validation instead of failing.
DatasetType = typing.Union[
    typing.Literal[
        "embed-input",
        "embed-result",
        "cluster-result",
        "cluster-outliers",
        "reranker-finetune-input",
        "single-label-classification-finetune-input",
        "chat-finetune-input",
        "multi-label-classification-finetune-input",
        "batch-chat-input",
        "batch-openai-chat-input",
        "batch-embed-v2-input",
        "batch-chat-v2-input",
    ],
    typing.Any,
]
================================================
FILE: src/cohere/types/dataset_validation_status.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Validation lifecycle states of a dataset; typing.Any tolerates values
# added by newer API versions.
DatasetValidationStatus = typing.Union[
    typing.Literal["unknown", "queued", "processing", "failed", "validated", "skipped"], typing.Any
]
================================================
FILE: src/cohere/types/delete_connector_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
DeleteConnectorResponse = typing.Dict[str, typing.Any]
================================================
FILE: src/cohere/types/detokenize_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
class DetokenizeResponse(UncheckedBaseModel):
    """
    Response from the detokenize endpoint.
    """

    text: str = pydantic.Field()
    """
    A string representing the list of tokens.
    """

    # API metadata (billing/version info); absent on some responses.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/document.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class Document(UncheckedBaseModel):
    """
    Relevant information that could be used by the model to generate a more accurate reply.
    The content of each document are generally short (should be under 300 words). Metadata should be used to provide additional information, both the key name and the value will be
    passed to the model.
    """

    data: typing.Dict[str, typing.Any] = pydantic.Field()
    """
    A relevant document that the model can cite to generate a more accurate reply. Each document is a string-any dictionary.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Unique identifier for this document which will be referenced in citations. If not provided an ID will be automatically generated.
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/document_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .document import Document
class DocumentContent(UncheckedBaseModel):
    """
    Document content.
    """

    # The wrapped document payload.
    document: Document

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_by_type_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings
from .embed_by_type_response_response_type import EmbedByTypeResponseResponseType
from .image import Image
class EmbedByTypeResponse(UncheckedBaseModel):
    """
    Embed response whose embeddings are grouped by embedding type
    (float, int8, uint8, binary, ubinary, base64).
    """

    response_type: typing.Optional[EmbedByTypeResponseResponseType] = None
    id: str
    embeddings: EmbedByTypeResponseEmbeddings = pydantic.Field()
    """
    An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array.
    """

    texts: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    The text entries for which embeddings were returned.
    """

    images: typing.Optional[typing.List[Image]] = pydantic.Field(default=None)
    """
    The image entries for which embeddings were returned.
    """

    # API metadata (billing/version info); absent on some responses.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_by_type_response_embeddings.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel
class EmbedByTypeResponseEmbeddings(UncheckedBaseModel):
    """
    An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array.
    """

    # Named "float_" locally because "float" shadows the builtin; serialized
    # over the wire as "float" via the alias.
    float_: typing_extensions.Annotated[
        typing.Optional[typing.List[typing.List[float]]],
        FieldMetadata(alias="float"),
        pydantic.Field(alias="float", description="An array of float embeddings."),
    ] = None
    int8: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of signed int8 embeddings. Each value is between -128 and 127.
    """

    uint8: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of unsigned int8 embeddings. Each value is between 0 and 255.
    """

    binary: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of packed signed binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between -128 and 127.
    """

    ubinary: typing.Optional[typing.List[typing.List[int]]] = pydantic.Field(default=None)
    """
    An array of packed unsigned binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between 0 and 255.
    """

    base64: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    An array of base64 embeddings. Each string is the result of appending the float embedding bytes together and base64 encoding that.
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_by_type_response_response_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
EmbedByTypeResponseResponseType = typing.Union[typing.Literal["embeddings_floats", "embeddings_by_type"], typing.Any]
================================================
FILE: src/cohere/types/embed_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .embed_image_url import EmbedImageUrl
class ImageUrlEmbedContent(UncheckedBaseModel):
    """
    Image-URL variant of the EmbedContent union.
    """

    # Discriminant value for the EmbedContent union.
    type: typing.Literal["image_url"] = "image_url"
    image_url: typing.Optional[EmbedImageUrl] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
class TextEmbedContent(UncheckedBaseModel):
    """
    Text variant of the EmbedContent union.
    """

    # Discriminant value for the EmbedContent union.
    type: typing.Literal["text"] = "text"
    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
# Tagged union over embed-input content blocks; the "type" field selects the variant.
EmbedContent = typing_extensions.Annotated[
    typing.Union[ImageUrlEmbedContent, TextEmbedContent], UnionMetadata(discriminant="type")
]
================================================
FILE: src/cohere/types/embed_floats_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .image import Image
class EmbedFloatsResponse(UncheckedBaseModel):
    """
    Embed response carrying plain float embeddings, one vector per input text.
    """

    id: str
    embeddings: typing.List[typing.List[float]] = pydantic.Field()
    """
    An array of embeddings, where each embedding is an array of floats. The length of the `embeddings` array will be the same as the length of the original `texts` array.
    """

    texts: typing.List[str] = pydantic.Field()
    """
    The text entries for which embeddings were returned.
    """

    images: typing.Optional[typing.List[Image]] = pydantic.Field(default=None)
    """
    The image entries for which embeddings were returned.
    """

    # API metadata (billing/version info); absent on some responses.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_image.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .embed_image_url import EmbedImageUrl
class EmbedImage(UncheckedBaseModel):
    """
    Image content of the input. Supported with Embed v3.0 and newer models.
    """

    # Base64 data-URL wrapper for the image (see EmbedImageUrl).
    image_url: typing.Optional[EmbedImageUrl] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_image_url.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class EmbedImageUrl(UncheckedBaseModel):
    """
    Base64 url of image.
    """

    # The base64 image URL string.
    url: str

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_input.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .embed_content import EmbedContent
class EmbedInput(UncheckedBaseModel):
    """
    A single input to embed, expressed as a list of content blocks.
    """

    content: typing.List[EmbedContent] = pydantic.Field()
    """
    An array of objects containing the input data for the model to embed.
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_input_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Purpose of the embedding input; typing.Any tolerates values added by newer API versions.
EmbedInputType = typing.Union[
    typing.Literal["search_document", "search_query", "classification", "clustering", "image"], typing.Any
]
================================================
FILE: src/cohere/types/embed_job.py
================================================
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .embed_job_status import EmbedJobStatus
from .embed_job_truncate import EmbedJobTruncate
class EmbedJob(UncheckedBaseModel):
    """
    An asynchronous embed job and its current state.
    """

    job_id: str = pydantic.Field()
    """
    ID of the embed job
    """

    name: typing.Optional[str] = pydantic.Field(default=None)
    """
    The name of the embed job
    """

    status: EmbedJobStatus = pydantic.Field()
    """
    The status of the embed job
    """

    created_at: dt.datetime = pydantic.Field()
    """
    The creation date of the embed job
    """

    input_dataset_id: str = pydantic.Field()
    """
    ID of the input dataset
    """

    output_dataset_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    ID of the resulting output dataset
    """

    model: str = pydantic.Field()
    """
    ID of the model used to embed
    """

    truncate: EmbedJobTruncate = pydantic.Field()
    """
    The truncation option used
    """

    # API metadata (billing/version info); absent on some responses.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embed_job_status.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
EmbedJobStatus = typing.Union[typing.Literal["processing", "complete", "cancelling", "cancelled", "failed"], typing.Any]
================================================
FILE: src/cohere/types/embed_job_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
EmbedJobTruncate = typing.Union[typing.Literal["START", "END"], typing.Any]
================================================
FILE: src/cohere/types/embed_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
EmbedRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]
================================================
FILE: src/cohere/types/embed_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .api_meta import ApiMeta
from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings
from .image import Image
class EmbeddingsFloatsEmbedResponse(UncheckedBaseModel):
    """
    Float-embeddings variant of the EmbedResponse union.
    """

    # Discriminant value for the EmbedResponse union.
    response_type: typing.Literal["embeddings_floats"] = "embeddings_floats"
    id: str
    # One float vector per input text.
    embeddings: typing.List[typing.List[float]]
    texts: typing.List[str]
    images: typing.Optional[typing.List[Image]] = None
    # API metadata (billing/version info); absent on some responses.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
class EmbeddingsByTypeEmbedResponse(UncheckedBaseModel):
    """
    By-type-embeddings variant of the EmbedResponse union.
    """

    # Discriminant value for the EmbedResponse union.
    response_type: typing.Literal["embeddings_by_type"] = "embeddings_by_type"
    id: str
    # Embeddings grouped by type (float, int8, uint8, binary, ubinary, base64).
    embeddings: EmbedByTypeResponseEmbeddings
    texts: typing.Optional[typing.List[str]] = None
    images: typing.Optional[typing.List[Image]] = None
    # API metadata (billing/version info); absent on some responses.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
# Tagged union over embed responses; the "response_type" field selects the variant.
EmbedResponse = typing_extensions.Annotated[
    typing.Union[EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse],
    UnionMetadata(discriminant="response_type"),
]
================================================
FILE: src/cohere/types/embed_text.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class EmbedText(UncheckedBaseModel):
    """
    Text content of the input.
    """

    # The text to embed.
    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/embedding_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
EmbeddingType = typing.Union[typing.Literal["float", "int8", "uint8", "binary", "ubinary", "base64"], typing.Any]
================================================
FILE: src/cohere/types/finetune_dataset_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class FinetuneDatasetMetrics(UncheckedBaseModel):
    """
    Size and example-count statistics for a fine-tuning dataset.
    """

    trainable_token_count: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of tokens of valid examples that can be used for training.
    """

    total_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    The overall number of examples.
    """

    train_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of training examples.
    """

    train_size_bytes: typing.Optional[int] = pydantic.Field(default=None)
    """
    The size in bytes of all training examples.
    """

    eval_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    Number of evaluation examples.
    """

    eval_size_bytes: typing.Optional[int] = pydantic.Field(default=None)
    """
    The size in bytes of all eval examples.
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/finish_reason.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Why a generation stopped; typing.Any tolerates values added by newer API versions.
FinishReason = typing.Union[
    typing.Literal[
        "COMPLETE", "STOP_SEQUENCE", "ERROR", "ERROR_TOXIC", "ERROR_LIMIT", "USER_CANCEL", "MAX_TOKENS", "TIMEOUT"
    ],
    typing.Any,
]
================================================
FILE: src/cohere/types/generate_request_return_likelihoods.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
GenerateRequestReturnLikelihoods = typing.Union[typing.Literal["GENERATION", "ALL", "NONE"], typing.Any]
================================================
FILE: src/cohere/types/generate_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
GenerateRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]
================================================
FILE: src/cohere/types/generate_stream_end.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .finish_reason import FinishReason
from .generate_stream_end_response import GenerateStreamEndResponse
from .generate_stream_event import GenerateStreamEvent
class GenerateStreamEnd(GenerateStreamEvent):
    """
    Terminal event of a generate stream, carrying the final aggregated response.
    """

    is_finished: bool
    finish_reason: typing.Optional[FinishReason] = None
    # Final aggregated response for the whole generation.
    response: GenerateStreamEndResponse

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/generate_stream_end_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .single_generation_in_stream import SingleGenerationInStream
class GenerateStreamEndResponse(UncheckedBaseModel):
    """
    Final payload delivered with the stream-end event of a generate stream.
    """

    id: str
    prompt: typing.Optional[str] = None
    generations: typing.Optional[typing.List[SingleGenerationInStream]] = None

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/generate_stream_error.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .finish_reason import FinishReason
from .generate_stream_event import GenerateStreamEvent
class GenerateStreamError(GenerateStreamEvent):
    """
    Error event emitted during a generate stream.
    """

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero.
    """

    is_finished: bool
    finish_reason: FinishReason
    err: str = pydantic.Field()
    """
    Error message
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/generate_stream_event.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class GenerateStreamEvent(UncheckedBaseModel):
    """
    Base class for generate-stream events; concrete events
    (text, error, stream-end) subclass this and add their own fields.
    """

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/generate_stream_request_return_likelihoods.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
GenerateStreamRequestReturnLikelihoods = typing.Union[typing.Literal["GENERATION", "ALL", "NONE"], typing.Any]
================================================
FILE: src/cohere/types/generate_stream_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
GenerateStreamRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]
================================================
FILE: src/cohere/types/generate_stream_text.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from .generate_stream_event import GenerateStreamEvent
class GenerateStreamText(GenerateStreamEvent):
    """
    Incremental text event of a generate stream.
    """

    text: str = pydantic.Field()
    """
    A segment of text of the generation.
    """

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero, and only when text responses are being streamed.
    """

    is_finished: bool

    if IS_PYDANTIC_V2:
        # Keep unrecognized keys from API payloads instead of rejecting them.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            # Pydantic v1 equivalent of the v2 config above.
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/generate_streamed_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .finish_reason import FinishReason
from .generate_stream_end_response import GenerateStreamEndResponse
class TextGenerationGenerateStreamedResponse(UncheckedBaseModel):
    """
    Response in content type stream when `stream` is `true` in the request parameters. Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse.
    """

    # Discriminant value for the union below.
    event_type: typing.Literal["text-generation"] = "text-generation"
    text: str
    index: typing.Optional[int] = None
    is_finished: bool

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class StreamEndGenerateStreamedResponse(UncheckedBaseModel):
    """
    Response in content type stream when `stream` is `true` in the request parameters. Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse.
    """

    # Discriminant value for the union below.
    event_type: typing.Literal["stream-end"] = "stream-end"
    is_finished: bool
    finish_reason: typing.Optional[FinishReason] = None
    # Full final response delivered with the terminating event.
    response: GenerateStreamEndResponse

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class StreamErrorGenerateStreamedResponse(UncheckedBaseModel):
    """
    Response in content type stream when `stream` is `true` in the request parameters. Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse.
    """

    # Discriminant value for the union below.
    event_type: typing.Literal["stream-error"] = "stream-error"
    index: typing.Optional[int] = None
    is_finished: bool
    finish_reason: FinishReason
    # Error message describing why the stream failed.
    err: str

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union of stream events, discriminated on the `event_type` field.
GenerateStreamedResponse = typing_extensions.Annotated[
    typing.Union[
        TextGenerationGenerateStreamedResponse, StreamEndGenerateStreamedResponse, StreamErrorGenerateStreamedResponse
    ],
    UnionMetadata(discriminant="event_type"),
]
================================================
FILE: src/cohere/types/generation.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .single_generation import SingleGeneration
class Generation(UncheckedBaseModel):
    """Response to a generate request: the prompt plus the list of generated results."""

    id: str

    prompt: typing.Optional[str] = pydantic.Field(default=None)
    """
    Prompt used for generations.
    """

    generations: typing.List[SingleGeneration] = pydantic.Field()
    """
    List of generated results
    """

    # API metadata (billing, warnings, etc.) — optional.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/get_connector_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector
class GetConnectorResponse(UncheckedBaseModel):
    """Wrapper around a single connector returned by the get-connector endpoint."""

    connector: Connector

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/get_model_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .compatible_endpoint import CompatibleEndpoint
from .get_model_response_sampling_defaults import GetModelResponseSamplingDefaults
class GetModelResponse(UncheckedBaseModel):
    """
    Contains information about the model and which API endpoints it can be used with.
    """

    name: typing.Optional[str] = pydantic.Field(default=None)
    """
    Specify this name in the `model` parameter of API requests to use your chosen model.
    """

    is_deprecated: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether the model is deprecated or not.
    """

    endpoints: typing.Optional[typing.List[CompatibleEndpoint]] = pydantic.Field(default=None)
    """
    The API endpoints that the model is compatible with.
    """

    finetuned: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Whether the model has been fine-tuned or not.
    """

    context_length: typing.Optional[float] = pydantic.Field(default=None)
    """
    The maximum number of tokens that the model can process in a single request. Note that not all of these tokens are always available due to special tokens and preambles that Cohere has added by default.
    """

    tokenizer_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    Public URL to the tokenizer's configuration file.
    """

    default_endpoints: typing.Optional[typing.List[CompatibleEndpoint]] = pydantic.Field(default=None)
    """
    The API endpoints that the model defaults to.
    """

    features: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    The features that the model supports.
    """

    sampling_defaults: typing.Optional[GetModelResponseSamplingDefaults] = pydantic.Field(default=None)
    """
    Default sampling parameters for this model when omitted from API requests.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/get_model_response_sampling_defaults.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class GetModelResponseSamplingDefaults(UncheckedBaseModel):
    """
    Default sampling parameters for this model when omitted from API requests.
    """

    # All fields are optional: absent means the model defines no default for it.
    # NOTE(review): `k`/`p` presumably correspond to the API's top-k / top-p
    # sampling parameters — confirm against the model API reference.
    temperature: typing.Optional[float] = None
    k: typing.Optional[int] = None
    p: typing.Optional[float] = None
    frequency_penalty: typing.Optional[float] = None
    presence_penalty: typing.Optional[float] = None
    max_tokens_per_doc: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/image.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class Image(UncheckedBaseModel):
    """Metadata describing an image: its dimensions, format, and bit depth."""

    width: int = pydantic.Field()
    """
    Width of the image in pixels
    """

    height: int = pydantic.Field()
    """
    Height of the image in pixels
    """

    format: str = pydantic.Field()
    """
    Format of the image
    """

    bit_depth: int = pydantic.Field()
    """
    Bit depth of the image
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/image_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .image_url import ImageUrl
class ImageContent(UncheckedBaseModel):
    """
    Image content of the message.
    """

    image_url: ImageUrl

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/image_url.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .image_url_detail import ImageUrlDetail
class ImageUrl(UncheckedBaseModel):
    """An image reference (base64 data URI or web URL) with optional processing detail level."""

    url: str = pydantic.Field()
    """
    URL of an image. Can be either a base64 data URI or a web URL.
    """

    detail: typing.Optional[ImageUrlDetail] = pydantic.Field(default=None)
    """
    Controls the level of detail in image processing. `"auto"` is the default and lets the system choose, `"low"` is faster but less detailed, and `"high"` preserves maximum detail. You can save tokens and speed up responses by using detail: `"low"`.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/image_url_detail.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Image-processing detail level. The trailing `typing.Any` keeps the union
# forward-compatible with values not yet listed in the literal.
ImageUrlDetail = typing.Union[typing.Literal["auto", "low", "high"], typing.Any]
================================================
FILE: src/cohere/types/json_response_format.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel
class JsonResponseFormat(UncheckedBaseModel):
    """Optional JSON schema constraining the model output (v1 API)."""

    # Trailing underscore avoids shadowing `BaseModel.schema`; both the Fern
    # FieldMetadata and the pydantic Field alias it back to "schema" on the wire.
    schema_: typing_extensions.Annotated[
        typing.Optional[typing.Dict[str, typing.Any]],
        FieldMetadata(alias="schema"),
        pydantic.Field(
            alias="schema",
            description='A JSON schema object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information.\nExample (required name and age object):\n```json\n{\n "type": "object",\n "properties": {\n "name": {"type": "string"},\n "age": {"type": "integer"}\n },\n "required": ["name", "age"]\n}\n```\n\n**Note**: This field must not be specified when the `type` is set to `"text"`.',
        ),
    ] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/json_response_format_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class JsonResponseFormatV2(UncheckedBaseModel):
    """Optional JSON schema constraining the model output (v2 API)."""

    json_schema: typing.Optional[typing.Dict[str, typing.Any]] = pydantic.Field(default=None)
    """
    A [JSON schema](https://json-schema.org/overview/what-is-jsonschema) object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information.
    Example (required name and age object):
    ```json
    {
    "type": "object",
    "properties": {
    "name": {"type": "string"},
    "age": {"type": "integer"}
    },
    "required": ["name", "age"]
    }
    ```
    **Note**: This field must not be specified when the `type` is set to `"text"`.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/label_metric.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class LabelMetric(UncheckedBaseModel):
    """Per-label statistics for a fine-tuning dataset."""

    total_examples: typing.Optional[int] = pydantic.Field(default=None)
    """
    Total number of examples for this label
    """

    label: typing.Optional[str] = pydantic.Field(default=None)
    """
    value of the label
    """

    samples: typing.Optional[typing.List[str]] = pydantic.Field(default=None)
    """
    samples for this label
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/list_connectors_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector
class ListConnectorsResponse(UncheckedBaseModel):
    """Paged list of connectors with the total count across all pages."""

    connectors: typing.List[Connector]

    total_count: typing.Optional[float] = pydantic.Field(default=None)
    """
    Total number of connectors.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/list_embed_job_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .embed_job import EmbedJob
class ListEmbedJobResponse(UncheckedBaseModel):
    """List of embed jobs returned by the list-embed-jobs endpoint."""

    embed_jobs: typing.Optional[typing.List[EmbedJob]] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/list_models_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .get_model_response import GetModelResponse
class ListModelsResponse(UncheckedBaseModel):
    """Paged list of models with an opaque token for fetching the next page."""

    models: typing.List[GetModelResponse]

    next_page_token: typing.Optional[str] = pydantic.Field(default=None)
    """
    A token to retrieve the next page of results. Provide in the page_token parameter of the next request.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/logprob_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class LogprobItem(UncheckedBaseModel):
    """Log-probability data for one chunk of generated text."""

    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    The text chunk for which the log probabilities were calculated.
    """

    token_ids: typing.List[int] = pydantic.Field()
    """
    The token ids of each token used to construct the text chunk.
    """

    logprobs: typing.Optional[typing.List[float]] = pydantic.Field(default=None)
    """
    The log probability of each token used to construct the text chunk.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/message.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .tool_call import ToolCall
from .tool_result import ToolResult
class ChatbotMessage(UncheckedBaseModel):
    """Chat-history message authored by the chatbot (role `CHATBOT`)."""

    # Discriminant value for the union below.
    role: typing.Literal["CHATBOT"] = "CHATBOT"
    message: str
    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class SystemMessage(UncheckedBaseModel):
    """Chat-history message with role `SYSTEM`."""

    # Discriminant value for the union below.
    role: typing.Literal["SYSTEM"] = "SYSTEM"
    message: str
    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class UserMessage(UncheckedBaseModel):
    """Chat-history message authored by the end user (role `USER`)."""

    # Discriminant value for the union below.
    role: typing.Literal["USER"] = "USER"
    message: str
    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolMessage(UncheckedBaseModel):
    """Chat-history message carrying tool execution results (role `TOOL`)."""

    # Discriminant value for the union below.
    role: typing.Literal["TOOL"] = "TOOL"
    tool_results: typing.Optional[typing.List[ToolResult]] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union of chat-history messages, discriminated on the `role` field.
Message = typing_extensions.Annotated[
    typing.Union[ChatbotMessage, SystemMessage, UserMessage, ToolMessage], UnionMetadata(discriminant="role")
]
================================================
FILE: src/cohere/types/metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .finetune_dataset_metrics import FinetuneDatasetMetrics
class Metrics(UncheckedBaseModel):
    """Container for fine-tuning metrics; currently only dataset metrics."""

    finetune_dataset_metrics: typing.Optional[FinetuneDatasetMetrics] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/non_streamed_chat_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .chat_citation import ChatCitation
from .chat_document import ChatDocument
from .chat_search_query import ChatSearchQuery
from .chat_search_result import ChatSearchResult
from .finish_reason import FinishReason
from .message import Message
from .tool_call import ToolCall
class NonStreamedChatResponse(UncheckedBaseModel):
    """Complete (non-streamed) reply from the v1 chat endpoint."""

    text: str = pydantic.Field()
    """
    Contents of the reply generated by the model.
    """

    generation_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Unique identifier for the generated reply. Useful for submitting feedback.
    """

    response_id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Unique identifier for the response.
    """

    citations: typing.Optional[typing.List[ChatCitation]] = pydantic.Field(default=None)
    """
    Inline citations for the generated reply.
    """

    documents: typing.Optional[typing.List[ChatDocument]] = pydantic.Field(default=None)
    """
    Documents seen by the model when generating the reply.
    """

    is_search_required: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Denotes that a search for documents is required during the RAG flow.
    """

    search_queries: typing.Optional[typing.List[ChatSearchQuery]] = pydantic.Field(default=None)
    """
    Generated search queries, meant to be used as part of the RAG flow.
    """

    search_results: typing.Optional[typing.List[ChatSearchResult]] = pydantic.Field(default=None)
    """
    Documents retrieved from each of the conducted searches.
    """

    finish_reason: typing.Optional[FinishReason] = None
    # Tool calls the model wants executed, if any.
    tool_calls: typing.Optional[typing.List[ToolCall]] = None

    chat_history: typing.Optional[typing.List[Message]] = pydantic.Field(default=None)
    """
    A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`.
    """

    # API metadata (billing, warnings, etc.) — optional.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/o_auth_authorize_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class OAuthAuthorizeResponse(UncheckedBaseModel):
    """Response from the connector OAuth authorize endpoint."""

    redirect_url: typing.Optional[str] = pydantic.Field(default=None)
    """
    The OAuth 2.0 redirect url. Redirect the user to this url to authorize the connector.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/parse_info.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ParseInfo(UncheckedBaseModel):
    """Parsing options for a dataset file.

    NOTE(review): field semantics (what `separator` vs `delimiter` apply to)
    are not documented here — confirm against the datasets API reference.
    """

    separator: typing.Optional[str] = None
    delimiter: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/rerank_document.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# A rerank document is a string-to-string mapping (e.g. {"text": "..."}).
RerankDocument = typing.Dict[str, str]
================================================
FILE: src/cohere/types/rerank_request_documents_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from .rerank_document import RerankDocument
# A document to rerank: either a bare string or a field mapping (RerankDocument).
RerankRequestDocumentsItem = typing.Union[str, RerankDocument]
================================================
FILE: src/cohere/types/rerank_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
from .rerank_response_results_item import RerankResponseResultsItem
class RerankResponse(UncheckedBaseModel):
    """Response from the rerank endpoint: documents ordered by relevance."""

    id: typing.Optional[str] = None

    results: typing.List[RerankResponseResultsItem] = pydantic.Field()
    """
    An ordered list of ranked documents
    """

    # API metadata (billing, warnings, etc.) — optional.
    meta: typing.Optional[ApiMeta] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/rerank_response_results_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .rerank_response_results_item_document import RerankResponseResultsItemDocument
class RerankResponseResultsItem(UncheckedBaseModel):
    """One ranked document in a rerank response."""

    document: typing.Optional[RerankResponseResultsItemDocument] = pydantic.Field(default=None)
    """
    If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in
    """

    index: int = pydantic.Field()
    """
    Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance)
    """

    relevance_score: float = pydantic.Field()
    """
    Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/rerank_response_results_item_document.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class RerankResponseResultsItemDocument(UncheckedBaseModel):
    """
    If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in
    """

    text: str = pydantic.Field()
    """
    The text of the document to rerank
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/reranker_data_metrics.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class RerankerDataMetrics(UncheckedBaseModel):
    """Dataset statistics for a reranker fine-tuning job (train/eval splits)."""

    num_train_queries: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of training queries.
    """

    num_train_relevant_passages: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all relevant passages of valid training examples.
    """

    num_train_hard_negatives: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all hard negatives of valid training examples.
    """

    num_eval_queries: typing.Optional[int] = pydantic.Field(default=None)
    """
    The number of evaluation queries.
    """

    num_eval_relevant_passages: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all relevant passages of valid eval examples.
    """

    num_eval_hard_negatives: typing.Optional[int] = pydantic.Field(default=None)
    """
    The sum of all hard negatives of valid eval examples.
    """

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/response_format.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.serialization import FieldMetadata
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
class TextResponseFormat(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R 03-2024](https://docs.cohere.com/docs/command-r), [Command R+ 04-2024](https://docs.cohere.com/docs/command-r-plus) and newer models.
    The model can be forced into outputting JSON objects (with up to 5 levels of nesting) by setting `{ "type": "json_object" }`.
    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.
    **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.
    **Limitation**: The parameter is not supported in RAG mode (when any of `connectors`, `documents`, `tools`, `tool_results` are provided).
    """

    # Discriminant value for the union below.
    type: typing.Literal["text"] = "text"

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class JsonObjectResponseFormat(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R 03-2024](https://docs.cohere.com/docs/command-r), [Command R+ 04-2024](https://docs.cohere.com/docs/command-r-plus) and newer models.
    The model can be forced into outputting JSON objects (with up to 5 levels of nesting) by setting `{ "type": "json_object" }`.
    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.
    **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.
    **Limitation**: The parameter is not supported in RAG mode (when any of `connectors`, `documents`, `tools`, `tool_results` are provided).
    """

    # Discriminant value for the union below.
    type: typing.Literal["json_object"] = "json_object"
    # Trailing underscore avoids shadowing `BaseModel.schema`; aliased back to
    # "schema" on the wire via both FieldMetadata and the pydantic Field alias.
    schema_: typing_extensions.Annotated[
        typing.Optional[typing.Dict[str, typing.Any]], FieldMetadata(alias="schema"), pydantic.Field(alias="schema")
    ] = None

    if IS_PYDANTIC_V2:
        # Permit unrecognized fields so newer API properties do not break deserialization.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 fallback mirroring the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Tagged union of response-format configs, discriminated on the `type` field.
ResponseFormat = typing_extensions.Annotated[
    typing.Union[TextResponseFormat, JsonObjectResponseFormat], UnionMetadata(discriminant="type")
]
================================================
FILE: src/cohere/types/response_format_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
class TextResponseFormatV2(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R](https://docs.cohere.com/v2/docs/command-r), [Command R+](https://docs.cohere.com/v2/docs/command-r-plus) and newer models.

    The model can be forced into outputting JSON objects by setting `{ "type": "json_object" }`.

    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.

    **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.

    **Note**: When `json_schema` is not specified, the generated object can have up to 5 layers of nesting.

    **Limitation**: The parameter is not supported when used in combinations with the `documents` or `tools` parameters.
    """

    # Discriminant value; selects this variant within the ResponseFormatV2 union below.
    type: typing.Literal["text"] = "text"

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class JsonObjectResponseFormatV2(UncheckedBaseModel):
    """
    Configuration for forcing the model output to adhere to the specified format. Supported on [Command R](https://docs.cohere.com/v2/docs/command-r), [Command R+](https://docs.cohere.com/v2/docs/command-r-plus) and newer models.

    The model can be forced into outputting JSON objects by setting `{ "type": "json_object" }`.

    A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure.

    **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length.

    **Note**: When `json_schema` is not specified, the generated object can have up to 5 layers of nesting.

    **Limitation**: The parameter is not supported when used in combinations with the `documents` or `tools` parameters.
    """

    type: typing.Literal["json_object"] = "json_object"
    # Optional JSON Schema constraining the generated object; unlike the v1
    # ResponseFormat, the v2 field is named `json_schema` and needs no alias.
    json_schema: typing.Optional[typing.Dict[str, typing.Any]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union over the v2 response-format variants, resolved via the `type` field.
ResponseFormatV2 = typing_extensions.Annotated[
    typing.Union[TextResponseFormatV2, JsonObjectResponseFormatV2], UnionMetadata(discriminant="type")
]
================================================
FILE: src/cohere/types/single_generation.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .single_generation_token_likelihoods_item import SingleGenerationTokenLikelihoodsItem
class SingleGeneration(UncheckedBaseModel):
    """
    A single completed generation, including its text and (optionally) likelihood data.
    """

    id: str
    text: str

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero.
    """

    # Average log-likelihood of the whole generation; see token_likelihoods below.
    likelihood: typing.Optional[float] = None

    token_likelihoods: typing.Optional[typing.List[SingleGenerationTokenLikelihoodsItem]] = pydantic.Field(default=None)
    """
    Only returned if `return_likelihoods` is set to `GENERATION` or `ALL`. The likelihood refers to the average log-likelihood of the entire specified string, which is useful for [evaluating the performance of your model](likelihood-eval), especially if you've created a [custom model](https://docs.cohere.com/docs/training-custom-models). Individual token likelihoods provide the log-likelihood of each token. The first token will not have a likelihood.
    """

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/single_generation_in_stream.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .finish_reason import FinishReason
class SingleGenerationInStream(UncheckedBaseModel):
    """
    A completed generation as delivered at the end of a streamed response.
    """

    id: str

    text: str = pydantic.Field()
    """
    Full text of the generation.
    """

    index: typing.Optional[int] = pydantic.Field(default=None)
    """
    Refers to the nth generation. Only present when `num_generations` is greater than zero.
    """

    # Why the model stopped generating (e.g. natural end, max tokens).
    finish_reason: FinishReason

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/single_generation_token_likelihoods_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class SingleGenerationTokenLikelihoodsItem(UncheckedBaseModel):
    """
    A (token, likelihood) pair within a generation's per-token likelihoods.
    """

    token: str
    # Log-likelihood of this token.
    likelihood: float

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/source.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
class ToolSource(UncheckedBaseModel):
    """
    A source object containing information about the source of the data cited.
    """

    # Discriminant value; selects this variant within the Source union below.
    type: typing.Literal["tool"] = "tool"
    id: typing.Optional[str] = None
    # Raw output of the tool invocation that produced the cited data.
    tool_output: typing.Optional[typing.Dict[str, typing.Any]] = None

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class DocumentSource(UncheckedBaseModel):
    """
    A source object containing information about the source of the data cited.
    """

    type: typing.Literal["document"] = "document"
    id: typing.Optional[str] = None
    # The document the citation was drawn from.
    document: typing.Optional[typing.Dict[str, typing.Any]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union over citation sources, resolved via the `type` field.
Source = typing_extensions.Annotated[typing.Union[ToolSource, DocumentSource], UnionMetadata(discriminant="type")]
================================================
FILE: src/cohere/types/streamed_chat_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .chat_citation import ChatCitation
from .chat_document import ChatDocument
from .chat_search_query import ChatSearchQuery
from .chat_search_result import ChatSearchResult
from .chat_stream_end_event_finish_reason import ChatStreamEndEventFinishReason
from .non_streamed_chat_response import NonStreamedChatResponse
from .tool_call import ToolCall
from .tool_call_delta import ToolCallDelta
class StreamStartStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `stream-start` event: the first event of a stream, carrying the generation id.
    """

    # Discriminant value; selects this variant within the StreamedChatResponse union below.
    event_type: typing.Literal["stream-start"] = "stream-start"
    generation_id: str

    # Accept unrecognized fields (extra="allow") so newer API events don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class SearchQueriesGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `search-queries-generation` event.
    """

    event_type: typing.Literal["search-queries-generation"] = "search-queries-generation"
    search_queries: typing.List[ChatSearchQuery]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class SearchResultsStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `search-results` event.
    """

    event_type: typing.Literal["search-results"] = "search-results"
    search_results: typing.Optional[typing.List[ChatSearchResult]] = None
    documents: typing.Optional[typing.List[ChatDocument]] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class TextGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `text-generation` event: an incremental chunk of generated text.
    """

    event_type: typing.Literal["text-generation"] = "text-generation"
    text: str

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class CitationGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `citation-generation` event.
    """

    event_type: typing.Literal["citation-generation"] = "citation-generation"
    citations: typing.List[ChatCitation]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolCallsGenerationStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `tool-calls-generation` event: the complete set of generated tool calls.
    """

    event_type: typing.Literal["tool-calls-generation"] = "tool-calls-generation"
    text: typing.Optional[str] = None
    tool_calls: typing.List[ToolCall]

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class StreamEndStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `stream-end` event: carries the finish reason plus the full
    aggregated (non-streamed) response.
    """

    event_type: typing.Literal["stream-end"] = "stream-end"
    finish_reason: ChatStreamEndEventFinishReason
    response: NonStreamedChatResponse

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolCallsChunkStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `tool-calls-chunk` event: an incremental delta of a tool call.
    """

    event_type: typing.Literal["tool-calls-chunk"] = "tool-calls-chunk"
    tool_call_delta: ToolCallDelta
    text: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class DebugStreamedChatResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).

    Variant for the `debug` event.
    """

    event_type: typing.Literal["debug"] = "debug"
    prompt: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union over all streamed chat events, resolved via the `event_type` field.
StreamedChatResponse = typing_extensions.Annotated[
    typing.Union[
        StreamStartStreamedChatResponse,
        SearchQueriesGenerationStreamedChatResponse,
        SearchResultsStreamedChatResponse,
        TextGenerationStreamedChatResponse,
        CitationGenerationStreamedChatResponse,
        ToolCallsGenerationStreamedChatResponse,
        StreamEndStreamedChatResponse,
        ToolCallsChunkStreamedChatResponse,
        DebugStreamedChatResponse,
    ],
    UnionMetadata(discriminant="event_type"),
]
================================================
FILE: src/cohere/types/summarize_request_extractiveness.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Extractiveness level for the Summarize request. The union with typing.Any
# means any value is accepted, keeping the type forward-compatible with
# values not yet in the Literal.
SummarizeRequestExtractiveness = typing.Union[typing.Literal["low", "medium", "high"], typing.Any]
================================================
FILE: src/cohere/types/summarize_request_format.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Output format for the Summarize request. The union with typing.Any means any
# value is accepted, keeping the type forward-compatible with new formats.
SummarizeRequestFormat = typing.Union[typing.Literal["paragraph", "bullets"], typing.Any]
================================================
FILE: src/cohere/types/summarize_request_length.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Target length for the Summarize request. The union with typing.Any means any
# value is accepted, keeping the type forward-compatible with new lengths.
SummarizeRequestLength = typing.Union[typing.Literal["short", "medium", "long"], typing.Any]
================================================
FILE: src/cohere/types/summarize_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
class SummarizeResponse(UncheckedBaseModel):
    """
    Response body of the Summarize endpoint.
    """

    id: typing.Optional[str] = pydantic.Field(default=None)
    """
    Generated ID for the summary
    """

    summary: typing.Optional[str] = pydantic.Field(default=None)
    """
    Generated summary for the text
    """

    # API metadata (e.g. billing/version info) attached to the response.
    meta: typing.Optional[ApiMeta] = None

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/system_message_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .system_message_v2content import SystemMessageV2Content
class SystemMessageV2(UncheckedBaseModel):
    """
    A message from the system.
    """

    # Either a plain string or a list of content blocks (see SystemMessageV2Content).
    content: SystemMessageV2Content

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/system_message_v2content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from .system_message_v2content_one_item import SystemMessageV2ContentOneItem
# System message content: either a plain string or a list of content-block items.
SystemMessageV2Content = typing.Union[str, typing.List[SystemMessageV2ContentOneItem]]
================================================
FILE: src/cohere/types/system_message_v2content_one_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class TextSystemMessageV2ContentOneItem(UncheckedBaseModel):
    """
    A text content block within a system message.
    """

    type: typing.Literal["text"] = "text"
    text: str

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Currently only the text variant exists, so the "union" is a single alias.
SystemMessageV2ContentOneItem = TextSystemMessageV2ContentOneItem
================================================
FILE: src/cohere/types/thinking.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .thinking_type import ThinkingType
class Thinking(UncheckedBaseModel):
    """
    Configuration for [reasoning features](https://docs.cohere.com/docs/reasoning).
    """

    type: ThinkingType = pydantic.Field()
    """
    Reasoning is enabled by default for models that support it, but can be turned off by setting `"type": "disabled"`.
    """

    token_budget: typing.Optional[int] = pydantic.Field(default=None)
    """
    The maximum number of tokens the model can use for thinking, which must be set to a positive integer.
    The model will stop thinking if it reaches the thinking token budget and will proceed with the response.
    """

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/thinking_type.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
# Whether reasoning is on or off. The union with typing.Any means any value is
# accepted, keeping the type forward-compatible with values not yet in the Literal.
ThinkingType = typing.Union[typing.Literal["enabled", "disabled"], typing.Any]
================================================
FILE: src/cohere/types/tokenize_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .api_meta import ApiMeta
class TokenizeResponse(UncheckedBaseModel):
    """
    Response body of the Tokenize endpoint: token ids plus their string forms.
    """

    tokens: typing.List[int] = pydantic.Field()
    """
    An array of tokens, where each token is an integer.
    """

    # String representation of each token, parallel to `tokens`.
    token_strings: typing.List[str]
    # API metadata attached to the response.
    meta: typing.Optional[ApiMeta] = None

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_parameter_definitions_value import ToolParameterDefinitionsValue
class Tool(UncheckedBaseModel):
    """
    Definition of a tool (function) the model may call during a chat.
    """

    name: str = pydantic.Field()
    """
    The name of the tool to be called. Valid names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit.
    """

    description: str = pydantic.Field()
    """
    The description of what the tool does, the model uses the description to choose when and how to call the function.
    """

    parameter_definitions: typing.Optional[typing.Dict[str, ToolParameterDefinitionsValue]] = pydantic.Field(
        default=None
    )
    """
    The input parameters of the tool. Accepts a dictionary where the key is the name of the parameter and the value is the parameter spec. Valid parameter names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit.
    ```
    {
        "my_param": {
            "description": <string>,
            "type": <string>, // any python data type, such as 'str', 'bool'
            "required": <boolean>
        }
    }
    ```
    """

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_call.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ToolCall(UncheckedBaseModel):
    """
    Contains the tool calls generated by the model. Use it to invoke your tools.
    """

    name: str = pydantic.Field()
    """
    Name of the tool to call.
    """

    parameters: typing.Dict[str, typing.Any] = pydantic.Field()
    """
    The name and value of the parameters to use when invoking a tool.
    """

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_call_delta.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ToolCallDelta(UncheckedBaseModel):
    """
    Contains the chunk of the tool call generation in the stream.
    """

    name: typing.Optional[str] = pydantic.Field(default=None)
    """
    Name of the tool call
    """

    # NOTE: float (not int) per the API definition this model was generated from.
    index: typing.Optional[float] = pydantic.Field(default=None)
    """
    Index of the tool call generated
    """

    # Partial JSON text; accumulate chunks across events to obtain full parameters.
    parameters: typing.Optional[str] = pydantic.Field(default=None)
    """
    Chunk of the tool parameters
    """

    text: typing.Optional[str] = pydantic.Field(default=None)
    """
    Chunk of the tool plan text
    """

    # Accept unrecognized fields (extra="allow") so newer API events don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_call_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_call_v2function import ToolCallV2Function
class ToolCallV2(UncheckedBaseModel):
    """
    An array of tool calls to be made.
    """

    id: str
    # Only "function" tool calls exist in the v2 API.
    type: typing.Literal["function"] = "function"
    function: typing.Optional[ToolCallV2Function] = None

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_call_v2function.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ToolCallV2Function(UncheckedBaseModel):
    """
    The function portion of a v2 tool call: its name and serialized arguments.
    """

    name: typing.Optional[str] = None
    # JSON-encoded arguments string (not a parsed dict).
    arguments: typing.Optional[str] = None

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_content.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations
import typing
import pydantic
import typing_extensions
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from .document import Document
class TextToolContent(UncheckedBaseModel):
    """
    A content block which contains information about the content of a tool result
    """

    # Discriminant value; selects this variant within the ToolContent union below.
    type: typing.Literal["text"] = "text"
    text: str

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class DocumentToolContent(UncheckedBaseModel):
    """
    A content block which contains information about the content of a tool result
    """

    type: typing.Literal["document"] = "document"
    document: Document

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union over tool-result content blocks, resolved via the `type` field.
ToolContent = typing_extensions.Annotated[
    typing.Union[TextToolContent, DocumentToolContent], UnionMetadata(discriminant="type")
]
================================================
FILE: src/cohere/types/tool_message_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_message_v2content import ToolMessageV2Content
class ToolMessageV2(UncheckedBaseModel):
    """
    A message with Tool outputs.
    """

    tool_call_id: str = pydantic.Field()
    """
    The id of the associated tool call that has provided the given content
    """

    content: ToolMessageV2Content = pydantic.Field()
    """
    Outputs from a tool. The content should be formatted as a JSON object string, or a list of tool content blocks
    """

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_message_v2content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from .tool_content import ToolContent
# Tool output: either a JSON object string or a list of tool content blocks.
ToolMessageV2Content = typing.Union[str, typing.List[ToolContent]]
================================================
FILE: src/cohere/types/tool_parameter_definitions_value.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ToolParameterDefinitionsValue(UncheckedBaseModel):
    """
    Spec for a single tool parameter (see Tool.parameter_definitions).
    """

    description: typing.Optional[str] = pydantic.Field(default=None)
    """
    The description of the parameter.
    """

    type: str = pydantic.Field()
    """
    The type of the parameter. Must be a valid Python type.
    """

    required: typing.Optional[bool] = pydantic.Field(default=None)
    """
    Denotes whether the parameter is always present (required) or not. Defaults to not required.
    """

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_result.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_call import ToolCall
class ToolResult(UncheckedBaseModel):
    """
    The result of executing a tool call: the originating call plus its outputs.
    """

    # The tool call this result answers.
    call: ToolCall
    # One or more output objects produced by the tool.
    outputs: typing.List[typing.Dict[str, typing.Any]]

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .tool_v2function import ToolV2Function
class ToolV2(UncheckedBaseModel):
type: typing.Literal["function"] = "function"
function: typing.Optional[ToolV2Function] = pydantic.Field(default=None)
"""
The function to be executed.
"""
if IS_PYDANTIC_V2:
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2
else:
class Config:
smart_union = True
extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/tool_v2function.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class ToolV2Function(UncheckedBaseModel):
    """
    The function to be executed.
    """

    name: str = pydantic.Field()
    """
    The name of the function.
    """

    description: typing.Optional[str] = pydantic.Field(default=None)
    """
    The description of the function.
    """

    parameters: typing.Dict[str, typing.Any] = pydantic.Field()
    """
    The parameters of the function as a JSON schema.
    """

    # Accept unrecognized fields (extra="allow") so newer API payloads don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/update_connector_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .connector import Connector
class UpdateConnectorResponse(UncheckedBaseModel):
    """
    Response body of the Update Connector endpoint: the updated connector.
    """

    connector: Connector

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/usage.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .usage_billed_units import UsageBilledUnits
from .usage_tokens import UsageTokens
class Usage(UncheckedBaseModel):
    """
    Usage accounting attached to a response: billed units and raw token counts.
    """

    billed_units: typing.Optional[UsageBilledUnits] = None
    tokens: typing.Optional[UsageTokens] = None

    cached_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of prompt tokens that hit the inference cache.
    """

    # Accept unrecognized fields (extra="allow") so newer API responses don't fail
    # validation; the config is spelled differently per pydantic major version.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/usage_billed_units.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class UsageBilledUnits(UncheckedBaseModel):
    """Billed-unit breakdown for a request; any unit not billed is None."""

    input_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed input tokens.
    """
    output_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed output tokens.
    """
    search_units: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed search units.
    """
    classifications: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of billed classifications units.
    """

    if IS_PYDANTIC_V2:
        # Pydantic v2: allow unknown fields so newer API properties don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 equivalent of the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/usage_tokens.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
class UsageTokens(UncheckedBaseModel):
    """Raw token counts for a request (as opposed to the billed counts)."""

    input_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of tokens used as input to the model.
    """
    output_tokens: typing.Optional[float] = pydantic.Field(default=None)
    """
    The number of tokens produced by the model.
    """

    if IS_PYDANTIC_V2:
        # Pydantic v2: allow unknown fields so newer API properties don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 equivalent of the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/user_message_v2.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
import pydantic
from ..core.pydantic_utilities import IS_PYDANTIC_V2
from ..core.unchecked_base_model import UncheckedBaseModel
from .user_message_v2content import UserMessageV2Content
class UserMessageV2(UncheckedBaseModel):
    """
    A message from the user.
    """

    content: UserMessageV2Content = pydantic.Field()
    """
    The content of the message. This can be a string or a list of content blocks.
    If a string is provided, it will be treated as a text content block.
    """

    if IS_PYDANTIC_V2:
        # Pydantic v2: allow unknown fields so newer API properties don't break parsing.
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:
        # Pydantic v1 equivalent of the v2 config above.
        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/types/user_message_v2content.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from .content import Content
# Either a plain string (treated as a single text content block) or a list of
# structured content blocks.
UserMessageV2Content = typing.Union[str, typing.List[Content]]
================================================
FILE: src/cohere/utils.py
================================================
import asyncio
import csv
import json
import time
import typing
from typing import Optional
import requests
from fastavro import parse_schema, reader, writer
from . import EmbedResponse, EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse, ApiMeta, \
EmbedByTypeResponseEmbeddings, ApiMetaBilledUnits, EmbedJob, CreateEmbedJobResponse, Dataset
from .datasets import DatasetsCreateResponse, DatasetsGetResponse
from .overrides import get_fields
# Note: utils.py does NOT call run_overrides() itself - that's done in client.py
# which imports utils.py. This ensures overrides are applied when client is used.
def get_terminal_states():
    """All states in which polling can stop: success or failure."""
    return get_success_states() | get_failed_states()


def get_success_states():
    """States meaning the job/dataset finished successfully."""
    return {"complete", "validated"}


def get_failed_states():
    """States meaning the job/dataset will never succeed.

    "unknown" is treated as failed so that polling terminates instead of
    spinning forever on a job the backend cannot identify.
    """
    # Fixed: the original set literal listed "failed" twice.
    return {"unknown", "failed", "skipped", "cancelled"}
def get_id(
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]):
    """Extract the job or dataset id from any of the supported response shapes."""
    direct = getattr(awaitable, "job_id", None) or getattr(awaitable, "id", None)
    if direct:
        return direct
    # Dataset responses nest the id one level down.
    return getattr(getattr(awaitable, "dataset", None), "id", None)


def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]):
    """Return the job status, or the dataset's validation status for dataset responses."""
    status = getattr(awaitable, "status", None)
    if status:
        return status
    return getattr(getattr(awaitable, "dataset", None), "validation_status", None)


def get_job(cohere: typing.Any,
            awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \
        typing.Union[EmbedJob, DatasetsGetResponse]:
    """Re-fetch the current state of the job/dataset referenced by ``awaitable``."""
    # Dispatch on the class name (not isinstance) so generated response types
    # from different modules still match.
    kind = awaitable.__class__.__name__
    if kind in ("EmbedJob", "CreateEmbedJobResponse"):
        return cohere.embed_jobs.get(id=get_id(awaitable))
    if kind in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return cohere.datasets.get(id=get_id(awaitable))
    raise ValueError(f"Unexpected awaitable type {awaitable}")
async def async_get_job(cohere: typing.Any, awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse]) -> \
        typing.Union[EmbedJob, DatasetsGetResponse]:
    """Async variant of ``get_job``: re-fetch the job/dataset state."""
    kind = awaitable.__class__.__name__
    if kind in ("EmbedJob", "CreateEmbedJobResponse"):
        return await cohere.embed_jobs.get(id=get_id(awaitable))
    if kind in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return await cohere.datasets.get(id=get_id(awaitable))
    raise ValueError(f"Unexpected awaitable type {awaitable}")
def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]:
    """Build a human-readable failure message for a failed job, or None if the type is unrecognized."""
    if isinstance(job, EmbedJob):
        return f"Embed job {job.job_id} failed with status {job.status}"
    if isinstance(job, DatasetsGetResponse):
        return (f"Dataset creation failed with status {job.dataset.validation_status}"
                f" and error : {job.dataset.validation_error}")
    return None
@typing.overload
def wait(
    cohere: typing.Any,
    awaitable: CreateEmbedJobResponse,
    timeout: Optional[float] = None,
    interval: float = 2,
) -> EmbedJob:
    ...


@typing.overload
def wait(
    cohere: typing.Any,
    awaitable: DatasetsCreateResponse,
    timeout: Optional[float] = None,
    interval: float = 2,
) -> DatasetsGetResponse:
    ...


def wait(
    cohere: typing.Any,
    awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
    timeout: Optional[float] = None,
    interval: float = 2,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Poll until the embed job / dataset referenced by ``awaitable`` reaches a terminal state.

    Parameters
    ----------
    cohere : the client used to re-fetch the job/dataset state.
    awaitable : the create-response (or job) to wait on.
    timeout : maximum seconds to wait; ``None`` waits indefinitely.
    interval : seconds between polls (default 2).

    Raises
    ------
    TimeoutError : if ``timeout`` elapses before a terminal state is reached.
    Exception : if the job ends in a failed state (message from ``get_failure_reason``).
    """
    # Fixed: the overloads previously declared interval=10 while the actual
    # runtime default was 2; the overloads now match the implementation.
    start_time = time.time()
    terminal_states = get_terminal_states()
    failed_states = get_failed_states()
    job = get_job(cohere, awaitable)
    while get_validation_status(job) not in terminal_states:
        if timeout is not None and time.time() - start_time > timeout:
            raise TimeoutError(f"wait timed out after {timeout} seconds")
        time.sleep(interval)
        print("...")  # progress indicator, kept for backward-compatible console output
        job = get_job(cohere, awaitable)
    if get_validation_status(job) in failed_states:
        raise Exception(get_failure_reason(job))
    return job
@typing.overload
async def async_wait(
    cohere: typing.Any,
    awaitable: CreateEmbedJobResponse,
    timeout: Optional[float] = None,
    interval: float = 10,
) -> EmbedJob:
    ...


@typing.overload
async def async_wait(
    cohere: typing.Any,
    awaitable: DatasetsCreateResponse,
    timeout: Optional[float] = None,
    interval: float = 10,
) -> DatasetsGetResponse:
    ...


async def async_wait(
    cohere: typing.Any,
    awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
    timeout: Optional[float] = None,
    interval: float = 10,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Async counterpart of ``wait``: poll until a terminal state, then return the job.

    Raises TimeoutError when ``timeout`` elapses, or Exception when the job
    ends in a failed state.
    """
    started = time.time()
    done_states = get_terminal_states()
    bad_states = get_failed_states()
    job = await async_get_job(cohere, awaitable)
    while True:
        if get_validation_status(job) in done_states:
            break
        if timeout is not None and time.time() - started > timeout:
            raise TimeoutError(f"wait timed out after {timeout} seconds")
        await asyncio.sleep(interval)
        print("...")  # progress indicator
        job = await async_get_job(cohere, awaitable)
    if get_validation_status(job) in bad_states:
        raise Exception(get_failure_reason(job))
    return job
def sum_fields_if_not_none(obj: typing.Any, field: str) -> Optional[int]:
    """Sum ``field`` across the items of ``obj``, skipping items where it is None.

    Returns None when no item carries a value, so callers can distinguish
    "no data at all" from a genuine total of 0.
    """
    # Fixed: the original comprehension reused the name `obj` for the loop
    # variable, shadowing the parameter it was iterating over.
    non_none = [getattr(item, field) for item in obj if getattr(item, field) is not None]
    return sum(non_none) if non_none else None
def merge_meta_field(metas: typing.List[ApiMeta]) -> ApiMeta:
    """Merge the ``meta`` fields of several API responses into one.

    Keeps the first response's api_version, sums billed-unit counts
    (None-safe), and de-duplicates warnings.
    """
    api_version = metas[0].api_version if metas else None
    # Fixed: billed_units is Optional on ApiMeta; a None entry would make
    # sum_fields_if_not_none raise AttributeError, so skip those metas.
    billed_units = [meta.billed_units for meta in metas if meta.billed_units is not None]
    input_tokens = sum_fields_if_not_none(billed_units, "input_tokens")
    output_tokens = sum_fields_if_not_none(billed_units, "output_tokens")
    search_units = sum_fields_if_not_none(billed_units, "search_units")
    classifications = sum_fields_if_not_none(billed_units, "classifications")
    warnings = {warning for meta in metas if meta.warnings for warning in meta.warnings}
    return ApiMeta(
        api_version=api_version,
        billed_units=ApiMetaBilledUnits(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            search_units=search_units,
            classifications=classifications
        ),
        warnings=list(warnings)
    )
def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedResponse:
    """Combine several embed responses (e.g. from batched requests) into one.

    The merged response's id is the comma-joined ids of the parts; texts and
    embeddings are concatenated in order; meta/billing is summed.
    NOTE(review): assumes every response has the same response_type as the
    first one — confirm against callers that batch requests uniformly.
    """
    meta = merge_meta_field([response.meta for response in responses if response.meta])
    # Keep traceability: the merged id records every constituent response id.
    response_id = ", ".join(response.id for response in responses)
    texts = [
        text
        for response in responses
        if response.texts is not None
        for text in response.texts
    ]
    if responses[0].response_type == "embeddings_floats":
        embeddings_floats = typing.cast(typing.List[EmbeddingsFloatsEmbedResponse], responses)
        embeddings = [
            embedding
            for embeddings_floats in embeddings_floats
            for embedding in embeddings_floats.embeddings
        ]
        return EmbeddingsFloatsEmbedResponse(
            response_type="embeddings_floats",
            id=response_id,
            texts=texts,
            embeddings=embeddings,
            meta=meta
        )
    else:
        embeddings_type = typing.cast(typing.List[EmbeddingsByTypeEmbedResponse], responses)
        embeddings_by_type = [
            response.embeddings
            for response in embeddings_type
        ]
        # only get set keys from the pydantic model (i.e. exclude fields that are set to 'None')
        fields = [x for x in get_fields(embeddings_type[0].embeddings) if getattr(embeddings_type[0].embeddings, x) is not None]
        # Concatenate each embedding type (float/int8/...) across all responses.
        merged_dicts = {
            field: [
                embedding
                for embedding_by_type in embeddings_by_type
                for embedding in getattr(embedding_by_type, field)
            ]
            for field in fields
        }
        # NOTE(review): parse_obj is the pydantic v1 API; under v2 it works via
        # deprecation shims — confirm against core.pydantic_utilities.
        embeddings_by_type_merged = EmbedByTypeResponseEmbeddings.parse_obj(merged_dicts)
        return EmbeddingsByTypeEmbedResponse(
            response_type="embeddings_by_type",
            id=response_id,
            embeddings=embeddings_by_type_merged,
            texts=texts,
            meta=meta
        )
# Formats accepted by SdkUtils.save_dataset; referenced in its error message.
supported_formats = ["jsonl", "csv", "avro"]
def save_avro(dataset: Dataset, filepath: str):
    """Write the dataset's records to ``filepath`` in Avro format.

    Raises ValueError if the dataset carries no Avro schema.
    """
    raw_schema = dataset.schema_
    if not raw_schema:
        raise ValueError("Dataset does not have a schema")
    parsed = parse_schema(json.loads(raw_schema))
    with open(filepath, "wb") as out:
        writer(out, parsed, dataset_generator(dataset))
def save_jsonl(dataset: Dataset, filepath: str):
    """Write the dataset's records to ``filepath`` as JSON Lines (one object per line)."""
    with open(filepath, "w") as out:
        for record in dataset_generator(dataset):
            out.write(json.dumps(record))
            out.write("\n")
def save_csv(dataset: Dataset, filepath: str):
    """Write the dataset's records to ``filepath`` as CSV.

    The header row is taken from the keys of the first record.
    """
    with open(filepath, "w") as out:
        csv_writer = None
        for record in dataset_generator(dataset):
            if csv_writer is None:
                # Lazily create the writer so the header matches the first record.
                csv_writer = csv.DictWriter(out, fieldnames=list(record.keys()))
                csv_writer.writeheader()
            csv_writer.writerow(record)
def dataset_generator(dataset: Dataset):
    """Yield records from every part of ``dataset``, streaming each part's Avro payload.

    Raises
    ------
    ValueError : if the dataset has no parts, or a part has no download URL.
    requests.HTTPError : if a part's URL cannot be downloaded successfully.
    """
    if not dataset.dataset_parts:
        raise ValueError("Dataset does not have dataset_parts")
    for part in dataset.dataset_parts:
        if not part.url:
            raise ValueError("Dataset part does not have a url")
        # Fixed: close the streamed connection once the part is consumed
        # (the response was previously never closed).
        with requests.get(part.url, stream=True) as resp:
            # Fail fast with a clear HTTP error instead of letting the Avro
            # reader choke on an error payload.
            resp.raise_for_status()
            for record in reader(resp.raw):  # type: ignore
                yield record
class SdkUtils:
    """Helpers exposed on the client object (``client.utils``)."""

    @staticmethod
    def save_dataset(dataset: Dataset, filepath: str, format: typing.Literal["jsonl", "csv", "avro"] = "jsonl"):
        """Save ``dataset`` to ``filepath`` in the requested format (jsonl/csv/avro)."""
        savers = {
            "jsonl": save_jsonl,
            "csv": save_csv,
            "avro": save_avro,
        }
        save = savers.get(format)
        if save is None:
            raise Exception(f"unsupported format must be one of : {supported_formats}")
        return save(dataset, filepath)


class SyncSdkUtils(SdkUtils):
    """Utilities attached to the synchronous client."""
    pass


class AsyncSdkUtils(SdkUtils):
    """Utilities attached to the asynchronous client."""
    pass
================================================
FILE: src/cohere/v2/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .types import (
CitationEndV2ChatStreamResponse,
CitationStartV2ChatStreamResponse,
ContentDeltaV2ChatStreamResponse,
ContentEndV2ChatStreamResponse,
ContentStartV2ChatStreamResponse,
DebugV2ChatStreamResponse,
MessageEndV2ChatStreamResponse,
MessageStartV2ChatStreamResponse,
ToolCallDeltaV2ChatStreamResponse,
ToolCallEndV2ChatStreamResponse,
ToolCallStartV2ChatStreamResponse,
ToolPlanDeltaV2ChatStreamResponse,
V2ChatRequestDocumentsItem,
V2ChatRequestSafetyMode,
V2ChatRequestToolChoice,
V2ChatResponse,
V2ChatStreamRequestDocumentsItem,
V2ChatStreamRequestSafetyMode,
V2ChatStreamRequestToolChoice,
V2ChatStreamResponse,
V2EmbedRequestTruncate,
V2RerankResponse,
V2RerankResponseResultsItem,
)
# Maps each public name to the relative module it lives in; consumed by
# __getattr__ below to import lazily on first attribute access (PEP 562).
_dynamic_imports: typing.Dict[str, str] = {
    "CitationEndV2ChatStreamResponse": ".types",
    "CitationStartV2ChatStreamResponse": ".types",
    "ContentDeltaV2ChatStreamResponse": ".types",
    "ContentEndV2ChatStreamResponse": ".types",
    "ContentStartV2ChatStreamResponse": ".types",
    "DebugV2ChatStreamResponse": ".types",
    "MessageEndV2ChatStreamResponse": ".types",
    "MessageStartV2ChatStreamResponse": ".types",
    "ToolCallDeltaV2ChatStreamResponse": ".types",
    "ToolCallEndV2ChatStreamResponse": ".types",
    "ToolCallStartV2ChatStreamResponse": ".types",
    "ToolPlanDeltaV2ChatStreamResponse": ".types",
    "V2ChatRequestDocumentsItem": ".types",
    "V2ChatRequestSafetyMode": ".types",
    "V2ChatRequestToolChoice": ".types",
    "V2ChatResponse": ".types",
    "V2ChatStreamRequestDocumentsItem": ".types",
    "V2ChatStreamRequestSafetyMode": ".types",
    "V2ChatStreamRequestToolChoice": ".types",
    "V2ChatStreamResponse": ".types",
    "V2EmbedRequestTruncate": ".types",
    "V2RerankResponse": ".types",
    "V2RerankResponseResultsItem": ".types",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily import ``attr_name`` on first access (PEP 562 module __getattr__)."""
    target = _dynamic_imports.get(attr_name)
    if target is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(target, __package__)
        # A mapping of ".foo" for attr "foo" means the attribute IS the submodule.
        return module if target == f".{attr_name}" else getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {target}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {target}: {e}") from e
def __dir__():
    """Expose the lazily-importable names to ``dir()``/autocomplete."""
    return sorted(_dynamic_imports)
# Public API of this package; kept in sync with _dynamic_imports above.
__all__ = [
    "CitationEndV2ChatStreamResponse",
    "CitationStartV2ChatStreamResponse",
    "ContentDeltaV2ChatStreamResponse",
    "ContentEndV2ChatStreamResponse",
    "ContentStartV2ChatStreamResponse",
    "DebugV2ChatStreamResponse",
    "MessageEndV2ChatStreamResponse",
    "MessageStartV2ChatStreamResponse",
    "ToolCallDeltaV2ChatStreamResponse",
    "ToolCallEndV2ChatStreamResponse",
    "ToolCallStartV2ChatStreamResponse",
    "ToolPlanDeltaV2ChatStreamResponse",
    "V2ChatRequestDocumentsItem",
    "V2ChatRequestSafetyMode",
    "V2ChatRequestToolChoice",
    "V2ChatResponse",
    "V2ChatStreamRequestDocumentsItem",
    "V2ChatStreamRequestSafetyMode",
    "V2ChatStreamRequestToolChoice",
    "V2ChatStreamResponse",
    "V2EmbedRequestTruncate",
    "V2RerankResponse",
    "V2RerankResponseResultsItem",
]
================================================
FILE: src/cohere/v2/client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from ..types.chat_messages import ChatMessages
from ..types.citation_options import CitationOptions
from ..types.embed_by_type_response import EmbedByTypeResponse
from ..types.embed_input import EmbedInput
from ..types.embed_input_type import EmbedInputType
from ..types.embedding_type import EmbeddingType
from ..types.response_format_v2 import ResponseFormatV2
from ..types.thinking import Thinking
from ..types.tool_v2 import ToolV2
from .raw_client import AsyncRawV2Client, RawV2Client
from .types.v2chat_request_documents_item import V2ChatRequestDocumentsItem
from .types.v2chat_request_safety_mode import V2ChatRequestSafetyMode
from .types.v2chat_request_tool_choice import V2ChatRequestToolChoice
from .types.v2chat_response import V2ChatResponse
from .types.v2chat_stream_request_documents_item import V2ChatStreamRequestDocumentsItem
from .types.v2chat_stream_request_safety_mode import V2ChatStreamRequestSafetyMode
from .types.v2chat_stream_request_tool_choice import V2ChatStreamRequestToolChoice
from .types.v2chat_stream_response import V2ChatStreamResponse
from .types.v2embed_request_truncate import V2EmbedRequestTruncate
from .types.v2rerank_response import V2RerankResponse
# this is used as the default value for optional parameters
# (an Ellipsis sentinel lets the client distinguish "not provided" from an explicit None)
OMIT = typing.cast(typing.Any, ...)
class V2Client:
def __init__(self, *, client_wrapper: SyncClientWrapper):
    # All HTTP work is delegated to the raw client; this wrapper only
    # unwraps its responses into parsed data.
    self._raw_client = RawV2Client(client_wrapper=client_wrapper)
@property
def with_raw_response(self) -> RawV2Client:
    """
    Retrieves a raw implementation of this client that returns raw responses.

    Returns
    -------
    RawV2Client
        The underlying raw client this wrapper delegates to.
    """
    return self._raw_client
def chat_stream(
    self,
    *,
    model: str,
    messages: ChatMessages,
    tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
    strict_tools: typing.Optional[bool] = OMIT,
    documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
    citation_options: typing.Optional[CitationOptions] = OMIT,
    response_format: typing.Optional[ResponseFormatV2] = OMIT,
    safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    logprobs: typing.Optional[bool] = OMIT,
    tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
    thinking: typing.Optional[Thinking] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[V2ChatStreamResponse]:
    """
    Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

    Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

    Parameters
    ----------
    model : str
        The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

    messages : ChatMessages

    tools : typing.Optional[typing.Sequence[ToolV2]]
        A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
        Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

    strict_tools : typing.Optional[bool]
        When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
        **Note**: The first few requests with a new set of tools will take longer to process.

    documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
        A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

    citation_options : typing.Optional[CitationOptions]

    response_format : typing.Optional[ResponseFormatV2]

    safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
        Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
        When `OFF` is specified, the safety instruction will be omitted.
        Safety modes are not yet configurable in combination with `tools` and `documents` parameters.
        **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
        **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

    max_tokens : typing.Optional[int]
        The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).
        **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.
        **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

    temperature : typing.Optional[float]
        Defaults to `0.3`.
        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
        Randomness can be further maximized by increasing the value of the `p` parameter.

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.

    frequency_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

    k : typing.Optional[int]
        Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
        Defaults to `0`, min value of `0`, max value of `500`.

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

    logprobs : typing.Optional[bool]
        Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

    tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
        Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
        When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
        If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.
        **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

    thinking : typing.Optional[Thinking]

    priority : typing.Optional[int]
        Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Yields
    ------
    typing.Iterator[V2ChatStreamResponse]

    Examples
    --------
    from cohere import Client, UserChatMessageV2

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    response = client.v2.chat_stream(
        model="command-a-03-2025",
        messages=[
            UserChatMessageV2(
                content="Tell me about LLMs",
            )
        ],
    )
    for chunk in response:
        print(chunk)
    """
    # Delegate to the raw client and re-yield its parsed stream events.
    with self._raw_client.chat_stream(
        model=model,
        messages=messages,
        tools=tools,
        strict_tools=strict_tools,
        documents=documents,
        citation_options=citation_options,
        response_format=response_format,
        safety_mode=safety_mode,
        max_tokens=max_tokens,
        stop_sequences=stop_sequences,
        temperature=temperature,
        seed=seed,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        k=k,
        p=p,
        logprobs=logprobs,
        tool_choice=tool_choice,
        thinking=thinking,
        priority=priority,
        request_options=request_options,
    ) as r:
        yield from r.data
def chat(
    self,
    *,
    model: str,
    messages: ChatMessages,
    tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
    strict_tools: typing.Optional[bool] = OMIT,
    documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
    citation_options: typing.Optional[CitationOptions] = OMIT,
    response_format: typing.Optional[ResponseFormatV2] = OMIT,
    safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    logprobs: typing.Optional[bool] = OMIT,
    tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
    thinking: typing.Optional[Thinking] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> V2ChatResponse:
    """
    Generates a text response to a user message and returns it in a single, complete response. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

    Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

    Parameters
    ----------
    model : str
        The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

    messages : ChatMessages

    tools : typing.Optional[typing.Sequence[ToolV2]]
        A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
        Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

    strict_tools : typing.Optional[bool]
        When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
        **Note**: The first few requests with a new set of tools will take longer to process.

    documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
        A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

    citation_options : typing.Optional[CitationOptions]

    response_format : typing.Optional[ResponseFormatV2]

    safety_mode : typing.Optional[V2ChatRequestSafetyMode]
        Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
        When `OFF` is specified, the safety instruction will be omitted.
        Safety modes are not yet configurable in combination with `tools` and `documents` parameters.
        **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
        **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

    max_tokens : typing.Optional[int]
        The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).
        **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.
        **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

    temperature : typing.Optional[float]
        Defaults to `0.3`.
        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
        Randomness can be further maximized by increasing the value of the `p` parameter.

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.

    frequency_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

    k : typing.Optional[int]
        Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
        Defaults to `0`, min value of `0`, max value of `500`.

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

    logprobs : typing.Optional[bool]
        Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

    tool_choice : typing.Optional[V2ChatRequestToolChoice]
        Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
        When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
        If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.
        **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

    thinking : typing.Optional[Thinking]

    priority : typing.Optional[int]
        Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    V2ChatResponse

    Examples
    --------
    from cohere import Client, UserChatMessageV2

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.v2.chat(
        model="command-a-03-2025",
        messages=[
            UserChatMessageV2(
                content="Tell me about LLMs",
            )
        ],
    )
    """
    # Delegate to the raw client and return only the parsed payload.
    _response = self._raw_client.chat(
        model=model,
        messages=messages,
        tools=tools,
        strict_tools=strict_tools,
        documents=documents,
        citation_options=citation_options,
        response_format=response_format,
        safety_mode=safety_mode,
        max_tokens=max_tokens,
        stop_sequences=stop_sequences,
        temperature=temperature,
        seed=seed,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        k=k,
        p=p,
        logprobs=logprobs,
        tool_choice=tool_choice,
        thinking=thinking,
        priority=priority,
        request_options=request_options,
    )
    return _response.data
def embed(
    self,
    *,
    model: str,
    input_type: EmbedInputType,
    texts: typing.Optional[typing.Sequence[str]] = OMIT,
    images: typing.Optional[typing.Sequence[str]] = OMIT,
    inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    output_dimension: typing.Optional[int] = OMIT,
    embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
    truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> EmbedByTypeResponse:
    """
    Compute embeddings for texts, images, or mixed inputs.

    An embedding is a list of floating point numbers that captures semantic
    information about the content it represents. Embeddings power use cases
    such as text classification and semantic search; see the
    [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

    Parameters
    ----------
    model : str
        ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

    input_type : EmbedInputType

    texts : typing.Optional[typing.Sequence[str]]
        Strings to embed. At most `96` texts per call.

    images : typing.Optional[typing.Sequence[str]]
        Image [data URIs](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data)
        to embed; at most `1` image per call. Accepted formats are `image/jpeg`,
        `image/png`, `image/webp`, and `image/gif`, with a maximum size of 5MB.
        Supported with Embed v3.0 and newer models.

    inputs : typing.Optional[typing.Sequence[EmbedInput]]
        Inputs to embed (each may mix text and image components). At most `96`
        inputs per call.

    max_tokens : typing.Optional[int]
        Maximum number of tokens to embed per input; longer input is truncated
        according to `truncate`.

    output_dimension : typing.Optional[int]
        Number of dimensions of the output embedding (`embed-v4` and newer
        models only). One of `256`, `512`, `1024`, or `1536`; defaults to `1536`.

    embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
        Which embedding encodings to return (one or more of):

        * `"float"`: default float embeddings; all Embed models.
        * `"int8"`: signed int8 embeddings; Embed v3.0 and newer.
        * `"uint8"`: unsigned int8 embeddings; Embed v3.0 and newer.
        * `"binary"`: signed binary embeddings; Embed v3.0 and newer.
        * `"ubinary"`: unsigned binary embeddings; Embed v3.0 and newer.
        * `"base64"`: base64 embeddings; Embed v3.0 and newer.

    truncate : typing.Optional[V2EmbedRequestTruncate]
        One of `NONE|START|END`, controlling how over-length input is handled:
        `START`/`END` discard from that side until the input fits the model's
        maximum token length; `NONE` returns an error instead.

    priority : typing.Optional[int]
        Controls how early the request is handled. Lower numbers indicate higher
        priority (default: 0, the highest). Under load, higher-priority requests
        are processed first and are the least likely to be dropped.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    EmbedByTypeResponse
        OK

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.v2.embed(
        texts=["hello", "goodbye"],
        model="embed-v4.0",
        input_type="classification",
        embedding_types=["float"],
    )
    """
    # Delegate to the raw client, which handles serialization and errors,
    # and unwrap the typed payload from the HTTP response wrapper.
    raw_response = self._raw_client.embed(
        model=model,
        input_type=input_type,
        texts=texts,
        images=images,
        inputs=inputs,
        max_tokens=max_tokens,
        output_dimension=output_dimension,
        embedding_types=embedding_types,
        truncate=truncate,
        priority=priority,
        request_options=request_options,
    )
    return raw_response.data
def rerank(
    self,
    *,
    model: str,
    query: str,
    documents: typing.Sequence[str],
    top_n: typing.Optional[int] = OMIT,
    max_tokens_per_doc: typing.Optional[int] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> V2RerankResponse:
    """
    Order a list of texts by relevance to a query.

    Each document is assigned a relevance score and the results are returned
    as an ordered array.

    Parameters
    ----------
    model : str
        The identifier of the model to use, e.g. `rerank-v3.5`.

    query : str
        The search query

    documents : typing.Sequence[str]
        The texts to compare against the `query`.

        For optimal performance we recommend against sending more than 1,000
        documents in a single request.

        **Note**: long documents will automatically be truncated to the value
        of `max_tokens_per_doc`.

        **Note**: structured data should be formatted as YAML strings for best
        performance.

    top_n : typing.Optional[int]
        Limits the number of returned rerank results to the specified value.
        If not passed, all the rerank results will be returned.

    max_tokens_per_doc : typing.Optional[int]
        Defaults to `4096`. Long documents will be automatically truncated to
        the specified number of tokens.

    priority : typing.Optional[int]
        Controls how early the request is handled. Lower numbers indicate higher
        priority (default: 0, the highest). Under load, higher-priority requests
        are processed first and are the least likely to be dropped.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    V2RerankResponse
        OK

    Examples
    --------
    from cohere import Client

    client = Client(
        client_name="YOUR_CLIENT_NAME",
        token="YOUR_TOKEN",
    )
    client.v2.rerank(
        documents=[
            "Carson City is the capital city of the American state of Nevada.",
            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
            "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
            "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
            "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
        ],
        query="What is the capital of the United States?",
        top_n=3,
        model="rerank-v4.0-pro",
    )
    """
    # Delegate to the raw client and unwrap the parsed body from the
    # HTTP response wrapper.
    ranked = self._raw_client.rerank(
        model=model,
        query=query,
        documents=documents,
        top_n=top_n,
        max_tokens_per_doc=max_tokens_per_doc,
        priority=priority,
        request_options=request_options,
    )
    return ranked.data
class AsyncV2Client:
    # Async wrapper over AsyncRawV2Client: each method delegates to the raw
    # client and returns the parsed `.data` payload instead of the raw
    # HTTP response wrapper.
    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._raw_client = AsyncRawV2Client(client_wrapper=client_wrapper)

    @property
    def with_raw_response(self) -> AsyncRawV2Client:
        """
        Retrieves a raw implementation of this client that returns raw responses.

        Returns
        -------
        AsyncRawV2Client
        """
        return self._raw_client

    async def chat_stream(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.AsyncIterator[V2ChatStreamResponse]:
        """
        Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

        Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.

            Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

        strict_tools : typing.Optional[bool]
            When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

            **Note**: The first few requests with a new set of tools will take longer to process.

        documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `OFF` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        max_tokens : typing.Optional[int]
            The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

            **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

            **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens
            deterministically, such that repeated requests with the same
            seed and parameters should return the same result. However,
            determinism cannot be totally guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        k : typing.Optional[int]
            Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`, min value of `0.01`, max value of `0.99`.

        logprobs : typing.Optional[bool]
            Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

        tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
            Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
            When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
            If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

            **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

        thinking : typing.Optional[Thinking]

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Yields
        ------
        typing.AsyncIterator[V2ChatStreamResponse]


        Examples
        --------
        import asyncio

        from cohere import AsyncClient, UserChatMessageV2

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            response = client.v2.chat_stream(
                model="command-a-03-2025",
                messages=[
                    UserChatMessageV2(
                        content="Tell me about LLMs",
                    )
                ],
            )
            async for chunk in response:
                print(chunk)


        asyncio.run(main())
        """
        # The raw client yields parsed stream events; re-yield them so callers
        # can `async for` directly over this method's result.
        async with self._raw_client.chat_stream(
            model=model,
            messages=messages,
            tools=tools,
            strict_tools=strict_tools,
            documents=documents,
            citation_options=citation_options,
            response_format=response_format,
            safety_mode=safety_mode,
            max_tokens=max_tokens,
            stop_sequences=stop_sequences,
            temperature=temperature,
            seed=seed,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            k=k,
            p=p,
            logprobs=logprobs,
            tool_choice=tool_choice,
            thinking=thinking,
            priority=priority,
            request_options=request_options,
        ) as r:
            async for _chunk in r.data:
                yield _chunk

    async def chat(
        self,
        *,
        model: str,
        messages: ChatMessages,
        tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
        strict_tools: typing.Optional[bool] = OMIT,
        documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
        citation_options: typing.Optional[CitationOptions] = OMIT,
        response_format: typing.Optional[ResponseFormatV2] = OMIT,
        safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
        temperature: typing.Optional[float] = OMIT,
        seed: typing.Optional[int] = OMIT,
        frequency_penalty: typing.Optional[float] = OMIT,
        presence_penalty: typing.Optional[float] = OMIT,
        k: typing.Optional[int] = OMIT,
        p: typing.Optional[float] = OMIT,
        logprobs: typing.Optional[bool] = OMIT,
        tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
        thinking: typing.Optional[Thinking] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> V2ChatResponse:
        """
        Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

        Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.

            Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

        strict_tools : typing.Optional[bool]
            When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

            **Note**: The first few requests with a new set of tools will take longer to process.

        documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `OFF` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        max_tokens : typing.Optional[int]
            The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

            **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

            **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens
            deterministically, such that repeated requests with the same
            seed and parameters should return the same result. However,
            determinism cannot be totally guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        k : typing.Optional[int]
            Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`, min value of `0.01`, max value of `0.99`.

        logprobs : typing.Optional[bool]
            Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

        tool_choice : typing.Optional[V2ChatRequestToolChoice]
            Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
            When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
            If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

            **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

        thinking : typing.Optional[Thinking]

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        V2ChatResponse


        Examples
        --------
        import asyncio

        from cohere import AsyncClient, UserChatMessageV2

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.v2.chat(
                model="command-a-03-2025",
                messages=[
                    UserChatMessageV2(
                        content="Tell me about LLMs",
                    )
                ],
            )


        asyncio.run(main())
        """
        # Delegate to the raw client and unwrap the parsed body.
        _response = await self._raw_client.chat(
            model=model,
            messages=messages,
            tools=tools,
            strict_tools=strict_tools,
            documents=documents,
            citation_options=citation_options,
            response_format=response_format,
            safety_mode=safety_mode,
            max_tokens=max_tokens,
            stop_sequences=stop_sequences,
            temperature=temperature,
            seed=seed,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            k=k,
            p=p,
            logprobs=logprobs,
            tool_choice=tool_choice,
            thinking=thinking,
            priority=priority,
            request_options=request_options,
        )
        return _response.data

    async def embed(
        self,
        *,
        model: str,
        input_type: EmbedInputType,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        output_dimension: typing.Optional[int] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> EmbedByTypeResponse:
        """
        This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.

        Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

        If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

        Parameters
        ----------
        model : str
            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

        input_type : EmbedInputType

        texts : typing.Optional[typing.Sequence[str]]
            An array of strings for the model to embed. Maximum number of texts per call is `96`.

        images : typing.Optional[typing.Sequence[str]]
            An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

            The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.

            Image embeddings are supported with Embed v3.0 and newer models.

        inputs : typing.Optional[typing.Sequence[EmbedInput]]
            An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components.

        max_tokens : typing.Optional[int]
            The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.

        output_dimension : typing.Optional[int]
            The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models.
            Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.

        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
            Specifies the types of embeddings you want to get back. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
            * `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models.

        truncate : typing.Optional[V2EmbedRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedByTypeResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.v2.embed(
                texts=["hello", "goodbye"],
                model="embed-v4.0",
                input_type="classification",
                embedding_types=["float"],
            )


        asyncio.run(main())
        """
        # Delegate to the raw client and unwrap the parsed body.
        _response = await self._raw_client.embed(
            model=model,
            input_type=input_type,
            texts=texts,
            images=images,
            inputs=inputs,
            max_tokens=max_tokens,
            output_dimension=output_dimension,
            embedding_types=embedding_types,
            truncate=truncate,
            priority=priority,
            request_options=request_options,
        )
        return _response.data

    async def rerank(
        self,
        *,
        model: str,
        query: str,
        documents: typing.Sequence[str],
        top_n: typing.Optional[int] = OMIT,
        max_tokens_per_doc: typing.Optional[int] = OMIT,
        priority: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> V2RerankResponse:
        """
        This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.

        Parameters
        ----------
        model : str
            The identifier of the model to use, e.g. `rerank-v3.5`.

        query : str
            The search query

        documents : typing.Sequence[str]
            A list of texts that will be compared to the `query`.

            For optimal performance we recommend against sending more than 1,000 documents in a single request.

            **Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`.

            **Note**: structured data should be formatted as YAML strings for best performance.

        top_n : typing.Optional[int]
            Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.

        max_tokens_per_doc : typing.Optional[int]
            Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.

        priority : typing.Optional[int]
            Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        V2RerankResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.v2.rerank(
                documents=[
                    "Carson City is the capital city of the American state of Nevada.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
                    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
                    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
                    "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
                ],
                query="What is the capital of the United States?",
                top_n=3,
                model="rerank-v4.0-pro",
            )


        asyncio.run(main())
        """
        # Delegate to the raw client and unwrap the parsed body.
        _response = await self._raw_client.rerank(
            model=model,
            query=query,
            documents=documents,
            top_n=top_n,
            max_tokens_per_doc=max_tokens_per_doc,
            priority=priority,
            request_options=request_options,
        )
        return _response.data
================================================
FILE: src/cohere/v2/raw_client.py
================================================
# This file was auto-generated by Fern from our API Definition.
import contextlib
import typing
from json.decoder import JSONDecodeError
from logging import error, warning
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.http_response import AsyncHttpResponse, HttpResponse
from ..core.http_sse._api import EventSource
from ..core.parse_error import ParsingError
from ..core.pydantic_utilities import parse_sse_obj
from ..core.request_options import RequestOptions
from ..core.serialization import convert_and_respect_annotation_metadata
from ..core.unchecked_base_model import construct_type
from ..errors.bad_request_error import BadRequestError
from ..errors.client_closed_request_error import ClientClosedRequestError
from ..errors.forbidden_error import ForbiddenError
from ..errors.gateway_timeout_error import GatewayTimeoutError
from ..errors.internal_server_error import InternalServerError
from ..errors.invalid_token_error import InvalidTokenError
from ..errors.not_found_error import NotFoundError
from ..errors.not_implemented_error import NotImplementedError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.too_many_requests_error import TooManyRequestsError
from ..errors.unauthorized_error import UnauthorizedError
from ..errors.unprocessable_entity_error import UnprocessableEntityError
from ..types.chat_messages import ChatMessages
from ..types.citation_options import CitationOptions
from ..types.embed_by_type_response import EmbedByTypeResponse
from ..types.embed_input import EmbedInput
from ..types.embed_input_type import EmbedInputType
from ..types.embedding_type import EmbeddingType
from ..types.response_format_v2 import ResponseFormatV2
from ..types.thinking import Thinking
from ..types.tool_v2 import ToolV2
from .types.v2chat_request_documents_item import V2ChatRequestDocumentsItem
from .types.v2chat_request_safety_mode import V2ChatRequestSafetyMode
from .types.v2chat_request_tool_choice import V2ChatRequestToolChoice
from .types.v2chat_response import V2ChatResponse
from .types.v2chat_stream_request_documents_item import V2ChatStreamRequestDocumentsItem
from .types.v2chat_stream_request_safety_mode import V2ChatStreamRequestSafetyMode
from .types.v2chat_stream_request_tool_choice import V2ChatStreamRequestToolChoice
from .types.v2chat_stream_response import V2ChatStreamResponse
from .types.v2embed_request_truncate import V2EmbedRequestTruncate
from .types.v2rerank_response import V2RerankResponse
from pydantic import ValidationError
# Sentinel default for optional request parameters. Each request method passes
# `omit=OMIT` to the HTTP client so that parameters left at this sentinel can be
# distinguished from an explicit `None` supplied by the caller.
OMIT = typing.cast(typing.Any, ...)
class RawV2Client:
def __init__(self, *, client_wrapper: SyncClientWrapper):
    """Store the shared client wrapper whose ``httpx_client`` issues all requests.

    Parameters
    ----------
    client_wrapper : SyncClientWrapper
        Wrapper providing the configured synchronous HTTP client.
    """
    self._client_wrapper = client_wrapper
@contextlib.contextmanager
def chat_stream(
    self,
    *,
    model: str,
    messages: ChatMessages,
    tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
    strict_tools: typing.Optional[bool] = OMIT,
    documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
    citation_options: typing.Optional[CitationOptions] = OMIT,
    response_format: typing.Optional[ResponseFormatV2] = OMIT,
    safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    logprobs: typing.Optional[bool] = OMIT,
    tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
    thinking: typing.Optional[Thinking] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[HttpResponse[typing.Iterator[V2ChatStreamResponse]]]:
    """
    Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

    Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

    Parameters
    ----------
    model : str
        The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

    messages : ChatMessages

    tools : typing.Optional[typing.Sequence[ToolV2]]
        A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
        Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

    strict_tools : typing.Optional[bool]
        When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

        **Note**: The first few requests with a new set of tools will take longer to process.

    documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
        A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

    citation_options : typing.Optional[CitationOptions]

    response_format : typing.Optional[ResponseFormatV2]

    safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
        Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
        When `OFF` is specified, the safety instruction will be omitted.

        Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

        **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

        **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

    max_tokens : typing.Optional[int]
        The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

        **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

        **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

    temperature : typing.Optional[float]
        Defaults to `0.3`.

        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

        Randomness can be further maximized by increasing the value of the `p` parameter.

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.

    frequency_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

    k : typing.Optional[int]
        Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
        Defaults to `0`, min value of `0`, max value of `500`.

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

    logprobs : typing.Optional[bool]
        Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

    tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
        Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
        When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
        If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

        **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

    thinking : typing.Optional[Thinking]

    priority : typing.Optional[int]
        Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Yields
    ------
    typing.Iterator[HttpResponse[typing.Iterator[V2ChatStreamResponse]]]
    """
    # Typed errors raised for known non-2xx status codes; anything else falls
    # through to a generic ApiError below.
    _status_errors = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    with self._client_wrapper.httpx_client.stream(
        "v2/chat",
        method="POST",
        json={
            "model": model,
            "messages": convert_and_respect_annotation_metadata(
                object_=messages, annotation=ChatMessages, direction="write"
            ),
            "tools": convert_and_respect_annotation_metadata(
                object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
            ),
            "strict_tools": strict_tools,
            "documents": convert_and_respect_annotation_metadata(
                object_=documents, annotation=typing.Sequence[V2ChatStreamRequestDocumentsItem], direction="write"
            ),
            "citation_options": convert_and_respect_annotation_metadata(
                object_=citation_options, annotation=CitationOptions, direction="write"
            ),
            "response_format": convert_and_respect_annotation_metadata(
                object_=response_format, annotation=ResponseFormatV2, direction="write"
            ),
            "safety_mode": safety_mode,
            "max_tokens": max_tokens,
            "stop_sequences": stop_sequences,
            "temperature": temperature,
            "seed": seed,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "k": k,
            "p": p,
            "logprobs": logprobs,
            "tool_choice": tool_choice,
            "thinking": convert_and_respect_annotation_metadata(
                object_=thinking, annotation=Thinking, direction="write"
            ),
            "priority": priority,
            "stream": True,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    ) as _response:

        def _stream() -> HttpResponse[typing.Iterator[V2ChatStreamResponse]]:
            try:
                if 200 <= _response.status_code < 300:

                    def _iter():
                        # Decode server-sent events lazily into typed stream
                        # responses; the terminal "[DONE]" event ends the stream.
                        _event_source = EventSource(_response)
                        for _sse in _event_source.iter_sse():
                            if _sse.data == "[DONE]":
                                return
                            try:
                                yield typing.cast(
                                    V2ChatStreamResponse,
                                    parse_sse_obj(
                                        sse=_sse,
                                        type_=V2ChatStreamResponse,  # type: ignore
                                    ),
                                )
                            except JSONDecodeError as e:
                                # Malformed event payload: skip it and keep streaming.
                                warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}")
                            except (TypeError, ValueError, KeyError, AttributeError) as e:
                                # Event parsed but did not fit the model: skip it.
                                warning(
                                    f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}"
                                )
                            except Exception as e:
                                # NOTE(review): on truly unexpected errors we log at
                                # error level and stop the stream rather than skip —
                                # confirm this matches the generator template.
                                error(
                                    f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}"
                                )
                                return

                    return HttpResponse(response=_response, data=_iter())
                # Error path: the streamed body must be fully read before it can
                # be parsed as JSON.
                _response.read()
                _error_cls = _status_errors.get(_response.status_code)
                if _error_cls is not None:
                    raise _error_cls(
                        headers=dict(_response.headers),
                        body=typing.cast(
                            typing.Any,
                            construct_type(
                                type_=typing.Any,  # type: ignore
                                object_=_response.json(),
                            ),
                        ),
                    )
                _response_json = _response.json()
            except JSONDecodeError:
                raise ApiError(
                    status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
                )
            except ValidationError as e:
                raise ParsingError(
                    status_code=_response.status_code,
                    headers=dict(_response.headers),
                    body=_response.json(),
                    cause=e,
                )
            raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)

        yield _stream()
def chat(
    self,
    *,
    model: str,
    messages: ChatMessages,
    tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
    strict_tools: typing.Optional[bool] = OMIT,
    documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
    citation_options: typing.Optional[CitationOptions] = OMIT,
    response_format: typing.Optional[ResponseFormatV2] = OMIT,
    safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
    temperature: typing.Optional[float] = OMIT,
    seed: typing.Optional[int] = OMIT,
    frequency_penalty: typing.Optional[float] = OMIT,
    presence_penalty: typing.Optional[float] = OMIT,
    k: typing.Optional[int] = OMIT,
    p: typing.Optional[float] = OMIT,
    logprobs: typing.Optional[bool] = OMIT,
    tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
    thinking: typing.Optional[Thinking] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[V2ChatResponse]:
    """
    Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).

    Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.

    Parameters
    ----------
    model : str
        The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).

    messages : ChatMessages

    tools : typing.Optional[typing.Sequence[ToolV2]]
        A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
        Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).

    strict_tools : typing.Optional[bool]
        When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

        **Note**: The first few requests with a new set of tools will take longer to process.

    documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
        A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

    citation_options : typing.Optional[CitationOptions]

    response_format : typing.Optional[ResponseFormatV2]

    safety_mode : typing.Optional[V2ChatRequestSafetyMode]
        Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
        When `OFF` is specified, the safety instruction will be omitted.

        Safety modes are not yet configurable in combination with `tools` and `documents` parameters.

        **Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

        **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

    max_tokens : typing.Optional[int]
        The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).

        **Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.

        **Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.

    stop_sequences : typing.Optional[typing.Sequence[str]]
        A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.

    temperature : typing.Optional[float]
        Defaults to `0.3`.

        A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

        Randomness can be further maximized by increasing the value of the `p` parameter.

    seed : typing.Optional[int]
        If specified, the backend will make a best effort to sample tokens
        deterministically, such that repeated requests with the same
        seed and parameters should return the same result. However,
        determinism cannot be totally guaranteed.

    frequency_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

    presence_penalty : typing.Optional[float]
        Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
        Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

    k : typing.Optional[int]
        Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
        Defaults to `0`, min value of `0`, max value of `500`.

    p : typing.Optional[float]
        Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
        Defaults to `0.75`. min value of `0.01`, max value of `0.99`.

    logprobs : typing.Optional[bool]
        Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

    tool_choice : typing.Optional[V2ChatRequestToolChoice]
        Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
        When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
        If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

        **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

    thinking : typing.Optional[Thinking]

    priority : typing.Optional[int]
        Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[V2ChatResponse]
    """
    # Typed errors raised for known non-2xx status codes; anything else falls
    # through to a generic ApiError below.
    _status_errors = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    _response = self._client_wrapper.httpx_client.request(
        "v2/chat",
        method="POST",
        json={
            "model": model,
            "messages": convert_and_respect_annotation_metadata(
                object_=messages, annotation=ChatMessages, direction="write"
            ),
            "tools": convert_and_respect_annotation_metadata(
                object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
            ),
            "strict_tools": strict_tools,
            "documents": convert_and_respect_annotation_metadata(
                object_=documents, annotation=typing.Sequence[V2ChatRequestDocumentsItem], direction="write"
            ),
            "citation_options": convert_and_respect_annotation_metadata(
                object_=citation_options, annotation=CitationOptions, direction="write"
            ),
            "response_format": convert_and_respect_annotation_metadata(
                object_=response_format, annotation=ResponseFormatV2, direction="write"
            ),
            "safety_mode": safety_mode,
            "max_tokens": max_tokens,
            "stop_sequences": stop_sequences,
            "temperature": temperature,
            "seed": seed,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "k": k,
            "p": p,
            "logprobs": logprobs,
            "tool_choice": tool_choice,
            "thinking": convert_and_respect_annotation_metadata(
                object_=thinking, annotation=Thinking, direction="write"
            ),
            "priority": priority,
            "stream": False,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    try:
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                V2ChatResponse,
                construct_type(
                    type_=V2ChatResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return HttpResponse(response=_response, data=_data)
        _error_cls = _status_errors.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Non-JSON body (e.g. an HTML error page): surface the raw text.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
def embed(
    self,
    *,
    model: str,
    input_type: EmbedInputType,
    texts: typing.Optional[typing.Sequence[str]] = OMIT,
    images: typing.Optional[typing.Sequence[str]] = OMIT,
    inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    output_dimension: typing.Optional[int] = OMIT,
    embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
    truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[EmbedByTypeResponse]:
    """
    This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.

    Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

    If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

    Parameters
    ----------
    model : str
        ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).

    input_type : EmbedInputType

    texts : typing.Optional[typing.Sequence[str]]
        An array of strings for the model to embed. Maximum number of texts per call is `96`.

    images : typing.Optional[typing.Sequence[str]]
        An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

        The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg`, `image/png`, `image/webp`, or `image/gif` format and has a maximum size of 5MB.

        Image embeddings are supported with Embed v3.0 and newer models.

    inputs : typing.Optional[typing.Sequence[EmbedInput]]
        An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components.

    max_tokens : typing.Optional[int]
        The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.

    output_dimension : typing.Optional[int]
        The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models.
        Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.

    embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
        Specifies the types of embeddings you want to get back. Can be one or more of the following types.

        * `"float"`: Use this when you want to get back the default float embeddings. Supported with all Embed models.
        * `"int8"`: Use this when you want to get back signed int8 embeddings. Supported with Embed v3.0 and newer Embed models.
        * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Supported with Embed v3.0 and newer Embed models.
        * `"binary"`: Use this when you want to get back signed binary embeddings. Supported with Embed v3.0 and newer Embed models.
        * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Supported with Embed v3.0 and newer Embed models.
        * `"base64"`: Use this when you want to get back base64 embeddings. Supported with Embed v3.0 and newer Embed models.

    truncate : typing.Optional[V2EmbedRequestTruncate]
        One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

        Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

        If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

    priority : typing.Optional[int]
        Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.

    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    HttpResponse[EmbedByTypeResponse]
        OK
    """
    # Request payload; fields left at the OMIT sentinel are dropped by the
    # client (omit=OMIT below).
    _json_body = {
        "texts": texts,
        "images": images,
        "model": model,
        "input_type": input_type,
        "inputs": convert_and_respect_annotation_metadata(
            object_=inputs, annotation=typing.Sequence[EmbedInput], direction="write"
        ),
        "max_tokens": max_tokens,
        "output_dimension": output_dimension,
        "embedding_types": embedding_types,
        "truncate": truncate,
        "priority": priority,
    }
    # Known non-2xx status codes and the typed error each one maps to.
    _error_by_status = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    _res = self._client_wrapper.httpx_client.request(
        "v2/embed",
        method="POST",
        json=_json_body,
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    try:
        if 200 <= _res.status_code < 300:
            _parsed = typing.cast(
                EmbedByTypeResponse,
                construct_type(
                    type_=EmbedByTypeResponse,  # type: ignore
                    object_=_res.json(),
                ),
            )
            return HttpResponse(response=_res, data=_parsed)
        _error_cls = _error_by_status.get(_res.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_res.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_res.json(),
                    ),
                ),
            )
        _body_json = _res.json()
    except JSONDecodeError:
        # Body was not JSON at all: report the raw text.
        raise ApiError(status_code=_res.status_code, headers=dict(_res.headers), body=_res.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_res.status_code, headers=dict(_res.headers), body=_res.json(), cause=e
        )
    # Unrecognized status code with a JSON body: generic API error.
    raise ApiError(status_code=_res.status_code, headers=dict(_res.headers), body=_body_json)
def rerank(
self,
*,
model: str,
query: str,
documents: typing.Sequence[str],
top_n: typing.Optional[int] = OMIT,
max_tokens_per_doc: typing.Optional[int] = OMIT,
priority: typing.Optional[int] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> HttpResponse[V2RerankResponse]:
"""
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
Parameters
----------
model : str
The identifier of the model to use, eg `rerank-v3.5`.
query : str
The search query
documents : typing.Sequence[str]
A list of texts that will be compared to the `query`.
For optimal performance we recommend against sending more than 1,000 documents in a single request.
**Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`.
**Note**: structured data should be formatted as YAML strings for best performance.
top_n : typing.Optional[int]
Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.
max_tokens_per_doc : typing.Optional[int]
Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.
priority : typing.Optional[int]
Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
HttpResponse[V2RerankResponse]
OK
"""
_response = self._client_wrapper.httpx_client.request(
"v2/rerank",
method="POST",
json={
"model": model,
"query": query,
"documents": documents,
"top_n": top_n,
"max_tokens_per_doc": max_tokens_per_doc,
"priority": priority,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
V2RerankResponse,
construct_type(
type_=V2RerankResponse, # type: ignore
object_=_response.json(),
),
)
return HttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
class AsyncRawV2Client:
    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        # Keep the wrapper that supplies the configured async httpx client
        # (base URL, auth headers, timeouts) used by every raw V2 endpoint call.
        self._client_wrapper = client_wrapper
@contextlib.asynccontextmanager
async def chat_stream(
self,
*,
model: str,
messages: ChatMessages,
tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
strict_tools: typing.Optional[bool] = OMIT,
documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
citation_options: typing.Optional[CitationOptions] = OMIT,
response_format: typing.Optional[ResponseFormatV2] = OMIT,
safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
temperature: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
logprobs: typing.Optional[bool] = OMIT,
tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
thinking: typing.Optional[Thinking] = OMIT,
priority: typing.Optional[int] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamResponse]]]:
"""
Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
Parameters
----------
model : str
The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).
messages : ChatMessages
tools : typing.Optional[typing.Sequence[ToolV2]]
A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).
strict_tools : typing.Optional[bool]
When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
**Note**: The first few requests with a new set of tools will take longer to process.
documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
citation_options : typing.Optional[CitationOptions]
response_format : typing.Optional[ResponseFormatV2]
safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `OFF` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools` and `documents` parameters.
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
max_tokens : typing.Optional[int]
The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).
**Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.
**Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.
stop_sequences : typing.Optional[typing.Sequence[str]]
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
temperature : typing.Optional[float]
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
frequency_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
k : typing.Optional[int]
Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
Defaults to `0`, min value of `0`, max value of `500`.
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
logprobs : typing.Optional[bool]
Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.
**Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.
thinking : typing.Optional[Thinking]
priority : typing.Optional[int]
Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Yields
------
typing.AsyncIterator[AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamResponse]]]
"""
async with self._client_wrapper.httpx_client.stream(
"v2/chat",
method="POST",
json={
"model": model,
"messages": convert_and_respect_annotation_metadata(
object_=messages, annotation=ChatMessages, direction="write"
),
"tools": convert_and_respect_annotation_metadata(
object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
),
"strict_tools": strict_tools,
"documents": convert_and_respect_annotation_metadata(
object_=documents, annotation=typing.Sequence[V2ChatStreamRequestDocumentsItem], direction="write"
),
"citation_options": convert_and_respect_annotation_metadata(
object_=citation_options, annotation=CitationOptions, direction="write"
),
"response_format": convert_and_respect_annotation_metadata(
object_=response_format, annotation=ResponseFormatV2, direction="write"
),
"safety_mode": safety_mode,
"max_tokens": max_tokens,
"stop_sequences": stop_sequences,
"temperature": temperature,
"seed": seed,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"k": k,
"p": p,
"logprobs": logprobs,
"tool_choice": tool_choice,
"thinking": convert_and_respect_annotation_metadata(
object_=thinking, annotation=Thinking, direction="write"
),
"priority": priority,
"stream": True,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
) as _response:
async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamResponse]]:
try:
if 200 <= _response.status_code < 300:
async def _iter():
_event_source = EventSource(_response)
async for _sse in _event_source.aiter_sse():
if _sse.data == "[DONE]":
return
try:
yield typing.cast(
V2ChatStreamResponse,
parse_sse_obj(
sse=_sse,
type_=V2ChatStreamResponse, # type: ignore
),
)
except JSONDecodeError as e:
warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}")
except (TypeError, ValueError, KeyError, AttributeError) as e:
warning(
f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}"
)
except Exception as e:
error(
f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}"
)
return
return AsyncHttpResponse(response=_response, data=_iter())
await _response.aread()
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.text
)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code,
headers=dict(_response.headers),
body=_response.json(),
cause=e,
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
yield await _stream()
async def chat(
self,
*,
model: str,
messages: ChatMessages,
tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
strict_tools: typing.Optional[bool] = OMIT,
documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
citation_options: typing.Optional[CitationOptions] = OMIT,
response_format: typing.Optional[ResponseFormatV2] = OMIT,
safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
max_tokens: typing.Optional[int] = OMIT,
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
temperature: typing.Optional[float] = OMIT,
seed: typing.Optional[int] = OMIT,
frequency_penalty: typing.Optional[float] = OMIT,
presence_penalty: typing.Optional[float] = OMIT,
k: typing.Optional[int] = OMIT,
p: typing.Optional[float] = OMIT,
logprobs: typing.Optional[bool] = OMIT,
tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
thinking: typing.Optional[Thinking] = OMIT,
priority: typing.Optional[int] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[V2ChatResponse]:
"""
Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
Parameters
----------
model : str
The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models).
messages : ChatMessages
tools : typing.Optional[typing.Sequence[ToolV2]]
A list of tools (functions) available to the model. The model response may contain 'tool_calls' to the specified tools.
Learn more in the [Tool Use guide](https://docs.cohere.com/docs/tools).
strict_tools : typing.Optional[bool]
When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
**Note**: The first few requests with a new set of tools will take longer to process.
documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
citation_options : typing.Optional[CitationOptions]
response_format : typing.Optional[ResponseFormatV2]
safety_mode : typing.Optional[V2ChatRequestSafetyMode]
Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
When `OFF` is specified, the safety instruction will be omitted.
Safety modes are not yet configurable in combination with `tools` and `documents` parameters.
**Note**: This parameter is only compatible newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).
**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
max_tokens : typing.Optional[int]
The maximum number of output tokens the model will generate in the response. If not set, `max_tokens` defaults to the model's maximum output token limit. You can find the maximum output token limits for each model in the [model documentation](https://docs.cohere.com/docs/models).
**Note**: Setting a low value may result in incomplete generations. In such cases, the `finish_reason` field in the response will be set to `"MAX_TOKENS"`.
**Note**: If `max_tokens` is set higher than the model's maximum output token limit, the generation will be capped at that model-specific maximum limit.
stop_sequences : typing.Optional[typing.Sequence[str]]
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence.
temperature : typing.Optional[float]
Defaults to `0.3`.
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
Randomness can be further maximized by increasing the value of the `p` parameter.
seed : typing.Optional[int]
If specified, the backend will make a best effort to sample tokens
deterministically, such that repeated requests with the same
seed and parameters should return the same result. However,
determinism cannot be totally guaranteed.
frequency_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
presence_penalty : typing.Optional[float]
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
k : typing.Optional[int]
Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
Defaults to `0`, min value of `0`, max value of `500`.
p : typing.Optional[float]
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
logprobs : typing.Optional[bool]
Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
tool_choice : typing.Optional[V2ChatRequestToolChoice]
Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.
**Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.
thinking : typing.Optional[Thinking]
priority : typing.Optional[int]
Controls how early the request is handled. Lower numbers indicate higher priority (default: 0, the highest). When the system is under load, higher-priority requests are processed first and are the least likely to be dropped.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
AsyncHttpResponse[V2ChatResponse]
"""
_response = await self._client_wrapper.httpx_client.request(
"v2/chat",
method="POST",
json={
"model": model,
"messages": convert_and_respect_annotation_metadata(
object_=messages, annotation=ChatMessages, direction="write"
),
"tools": convert_and_respect_annotation_metadata(
object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
),
"strict_tools": strict_tools,
"documents": convert_and_respect_annotation_metadata(
object_=documents, annotation=typing.Sequence[V2ChatRequestDocumentsItem], direction="write"
),
"citation_options": convert_and_respect_annotation_metadata(
object_=citation_options, annotation=CitationOptions, direction="write"
),
"response_format": convert_and_respect_annotation_metadata(
object_=response_format, annotation=ResponseFormatV2, direction="write"
),
"safety_mode": safety_mode,
"max_tokens": max_tokens,
"stop_sequences": stop_sequences,
"temperature": temperature,
"seed": seed,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"k": k,
"p": p,
"logprobs": logprobs,
"tool_choice": tool_choice,
"thinking": convert_and_respect_annotation_metadata(
object_=thinking, annotation=Thinking, direction="write"
),
"priority": priority,
"stream": False,
},
headers={
"content-type": "application/json",
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
_data = typing.cast(
V2ChatResponse,
construct_type(
type_=V2ChatResponse, # type: ignore
object_=_response.json(),
),
)
return AsyncHttpResponse(response=_response, data=_data)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Any,
construct_type(
type_=typing.Any, # type: ignore
object_=_response.json(),
),
),
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
except ValidationError as e:
raise ParsingError(
status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def embed(
    self,
    *,
    model: str,
    input_type: EmbedInputType,
    texts: typing.Optional[typing.Sequence[str]] = OMIT,
    images: typing.Optional[typing.Sequence[str]] = OMIT,
    inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
    max_tokens: typing.Optional[int] = OMIT,
    output_dimension: typing.Optional[int] = OMIT,
    embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
    truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[EmbedByTypeResponse]:
    """
    Call the `v2/embed` endpoint and return the parsed response.

    An embedding is a list of floating point numbers that captures semantic
    information about the text (or image) it represents. Embeddings can power
    text classifiers and semantic search; see the
    [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

    Parameters
    ----------
    model : str
        ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
    input_type : EmbedInputType
    texts : typing.Optional[typing.Sequence[str]]
        Strings to embed. Maximum number of texts per call is `96`.
    images : typing.Optional[typing.Sequence[str]]
        Image [data URIs](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data) to embed
        (max `1` per call; `image/jpeg`, `image/png`, `image/webp` or `image/gif`; max 5MB).
        Supported with Embed v3.0 and newer models.
    inputs : typing.Optional[typing.Sequence[EmbedInput]]
        Mixed text/image inputs. Maximum number of inputs per call is `96`.
    max_tokens : typing.Optional[int]
        Maximum number of tokens to embed per input; longer inputs are handled
        according to `truncate`.
    output_dimension : typing.Optional[int]
        Output embedding size for `embed-v4` and newer models
        (`256`, `512`, `1024` or `1536`; default `1536`).
    embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
        Which encodings to return: `"float"`, `"int8"`, `"uint8"`, `"binary"`,
        `"ubinary"`, `"base64"` (non-float types need Embed v3.0+).
    truncate : typing.Optional[V2EmbedRequestTruncate]
        One of `NONE|START|END`: which side of an over-length input to discard.
        With `NONE`, over-length input produces an error instead.
    priority : typing.Optional[int]
        Scheduling priority; lower numbers are handled earlier (default 0 is
        the highest priority).
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[EmbedByTypeResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v2/embed",
        method="POST",
        json={
            "texts": texts,
            "images": images,
            "model": model,
            "input_type": input_type,
            "inputs": convert_and_respect_annotation_metadata(
                object_=inputs, annotation=typing.Sequence[EmbedInput], direction="write"
            ),
            "max_tokens": max_tokens,
            "output_dimension": output_dimension,
            "embedding_types": embedding_types,
            "truncate": truncate,
            "priority": priority,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    # Error statuses are mapped to their exception class and raised uniformly
    # below, instead of one hand-written branch per status code.
    _status_to_error: typing.Dict[int, typing.Any] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                EmbedByTypeResponse,
                construct_type(
                    type_=EmbedByTypeResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Body was not valid JSON; surface the raw text instead.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code: raise a generic ApiError with the JSON body.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
async def rerank(
    self,
    *,
    model: str,
    query: str,
    documents: typing.Sequence[str],
    top_n: typing.Optional[int] = OMIT,
    max_tokens_per_doc: typing.Optional[int] = OMIT,
    priority: typing.Optional[int] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> AsyncHttpResponse[V2RerankResponse]:
    """
    Call the `v2/rerank` endpoint: score each text against a query and return
    an ordered array of results with relevance scores.

    Parameters
    ----------
    model : str
        The identifier of the model to use, eg `rerank-v3.5`.
    query : str
        The search query
    documents : typing.Sequence[str]
        Texts to compare to the `query`. For optimal performance, avoid
        sending more than 1,000 documents per request. Long documents are
        truncated to `max_tokens_per_doc`; structured data performs best as
        YAML strings.
    top_n : typing.Optional[int]
        Limit on the number of returned results; all results when omitted.
    max_tokens_per_doc : typing.Optional[int]
        Per-document truncation length in tokens (default `4096`).
    priority : typing.Optional[int]
        Scheduling priority; lower numbers are handled earlier (default 0 is
        the highest priority).
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    AsyncHttpResponse[V2RerankResponse]
        OK
    """
    _response = await self._client_wrapper.httpx_client.request(
        "v2/rerank",
        method="POST",
        json={
            "model": model,
            "query": query,
            "documents": documents,
            "top_n": top_n,
            "max_tokens_per_doc": max_tokens_per_doc,
            "priority": priority,
        },
        headers={
            "content-type": "application/json",
        },
        request_options=request_options,
        omit=OMIT,
    )
    # Error statuses are mapped to their exception class and raised uniformly
    # below, instead of one hand-written branch per status code.
    _status_to_error: typing.Dict[int, typing.Any] = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: NotFoundError,
        422: UnprocessableEntityError,
        429: TooManyRequestsError,
        498: InvalidTokenError,
        499: ClientClosedRequestError,
        500: InternalServerError,
        501: NotImplementedError,
        503: ServiceUnavailableError,
        504: GatewayTimeoutError,
    }
    try:
        if 200 <= _response.status_code < 300:
            _data = typing.cast(
                V2RerankResponse,
                construct_type(
                    type_=V2RerankResponse,  # type: ignore
                    object_=_response.json(),
                ),
            )
            return AsyncHttpResponse(response=_response, data=_data)
        _error_cls = _status_to_error.get(_response.status_code)
        if _error_cls is not None:
            raise _error_cls(
                headers=dict(_response.headers),
                body=typing.cast(
                    typing.Any,
                    construct_type(
                        type_=typing.Any,  # type: ignore
                        object_=_response.json(),
                    ),
                ),
            )
        _response_json = _response.json()
    except JSONDecodeError:
        # Body was not valid JSON; surface the raw text instead.
        raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
    except ValidationError as e:
        raise ParsingError(
            status_code=_response.status_code, headers=dict(_response.headers), body=_response.json(), cause=e
        )
    # Unrecognized status code: raise a generic ApiError with the JSON body.
    raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
================================================
FILE: src/cohere/v2/types/__init__.py
================================================
# This file was auto-generated by Fern from our API Definition.

# isort: skip_file
import typing
from importlib import import_module

# These imports are evaluated only by static type checkers; at runtime the
# same names are resolved lazily through the module-level __getattr__ below
# (PEP 562), which keeps importing this package cheap.
if typing.TYPE_CHECKING:
    from .v2chat_request_documents_item import V2ChatRequestDocumentsItem
    from .v2chat_request_safety_mode import V2ChatRequestSafetyMode
    from .v2chat_request_tool_choice import V2ChatRequestToolChoice
    from .v2chat_response import V2ChatResponse
    from .v2chat_stream_request_documents_item import V2ChatStreamRequestDocumentsItem
    from .v2chat_stream_request_safety_mode import V2ChatStreamRequestSafetyMode
    from .v2chat_stream_request_tool_choice import V2ChatStreamRequestToolChoice
    from .v2chat_stream_response import (
        CitationEndV2ChatStreamResponse,
        CitationStartV2ChatStreamResponse,
        ContentDeltaV2ChatStreamResponse,
        ContentEndV2ChatStreamResponse,
        ContentStartV2ChatStreamResponse,
        DebugV2ChatStreamResponse,
        MessageEndV2ChatStreamResponse,
        MessageStartV2ChatStreamResponse,
        ToolCallDeltaV2ChatStreamResponse,
        ToolCallEndV2ChatStreamResponse,
        ToolCallStartV2ChatStreamResponse,
        ToolPlanDeltaV2ChatStreamResponse,
        V2ChatStreamResponse,
    )
    from .v2embed_request_truncate import V2EmbedRequestTruncate
    from .v2rerank_response import V2RerankResponse
    from .v2rerank_response_results_item import V2RerankResponseResultsItem

# Maps each public attribute name to the relative module that defines it;
# consumed by __getattr__ to perform the lazy import.
_dynamic_imports: typing.Dict[str, str] = {
    "CitationEndV2ChatStreamResponse": ".v2chat_stream_response",
    "CitationStartV2ChatStreamResponse": ".v2chat_stream_response",
    "ContentDeltaV2ChatStreamResponse": ".v2chat_stream_response",
    "ContentEndV2ChatStreamResponse": ".v2chat_stream_response",
    "ContentStartV2ChatStreamResponse": ".v2chat_stream_response",
    "DebugV2ChatStreamResponse": ".v2chat_stream_response",
    "MessageEndV2ChatStreamResponse": ".v2chat_stream_response",
    "MessageStartV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolCallDeltaV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolCallEndV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolCallStartV2ChatStreamResponse": ".v2chat_stream_response",
    "ToolPlanDeltaV2ChatStreamResponse": ".v2chat_stream_response",
    "V2ChatRequestDocumentsItem": ".v2chat_request_documents_item",
    "V2ChatRequestSafetyMode": ".v2chat_request_safety_mode",
    "V2ChatRequestToolChoice": ".v2chat_request_tool_choice",
    "V2ChatResponse": ".v2chat_response",
    "V2ChatStreamRequestDocumentsItem": ".v2chat_stream_request_documents_item",
    "V2ChatStreamRequestSafetyMode": ".v2chat_stream_request_safety_mode",
    "V2ChatStreamRequestToolChoice": ".v2chat_stream_request_tool_choice",
    "V2ChatStreamResponse": ".v2chat_stream_response",
    "V2EmbedRequestTruncate": ".v2embed_request_truncate",
    "V2RerankResponse": ".v2rerank_response",
    "V2RerankResponseResultsItem": ".v2rerank_response_results_item",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily resolve a public attribute via the `_dynamic_imports` table (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # An entry of the form "Name" -> ".Name" means the attribute *is* the
        # submodule itself, not a name defined inside it.
        return module if module_name == f".{attr_name}" else getattr(module, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Advertise the lazily importable attribute names to dir()."""
    # Iterating the dict yields its keys, so this matches list(...keys()).
    return sorted(_dynamic_imports)
# Public API of this package; the entries mirror the keys of _dynamic_imports.
__all__ = [
    "CitationEndV2ChatStreamResponse",
    "CitationStartV2ChatStreamResponse",
    "ContentDeltaV2ChatStreamResponse",
    "ContentEndV2ChatStreamResponse",
    "ContentStartV2ChatStreamResponse",
    "DebugV2ChatStreamResponse",
    "MessageEndV2ChatStreamResponse",
    "MessageStartV2ChatStreamResponse",
    "ToolCallDeltaV2ChatStreamResponse",
    "ToolCallEndV2ChatStreamResponse",
    "ToolCallStartV2ChatStreamResponse",
    "ToolPlanDeltaV2ChatStreamResponse",
    "V2ChatRequestDocumentsItem",
    "V2ChatRequestSafetyMode",
    "V2ChatRequestToolChoice",
    "V2ChatResponse",
    "V2ChatStreamRequestDocumentsItem",
    "V2ChatStreamRequestSafetyMode",
    "V2ChatStreamRequestToolChoice",
    "V2ChatStreamResponse",
    "V2EmbedRequestTruncate",
    "V2RerankResponse",
    "V2RerankResponseResultsItem",
]
================================================
FILE: src/cohere/v2/types/v2chat_request_documents_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

from ...types.document import Document

# A request document may be supplied either as a plain string or as a
# structured Document object.
V2ChatRequestDocumentsItem = typing.Union[str, Document]
================================================
FILE: src/cohere/v2/types/v2chat_request_safety_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Known values are "CONTEXTUAL", "STRICT" and "OFF"; the union with Any keeps
# the alias permissive toward values outside that literal set.
V2ChatRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "OFF"], typing.Any]
================================================
FILE: src/cohere/v2/types/v2chat_request_tool_choice.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Known values are "REQUIRED" and "NONE"; the union with Any keeps the alias
# permissive toward values outside that literal set.
V2ChatRequestToolChoice = typing.Union[typing.Literal["REQUIRED", "NONE"], typing.Any]
================================================
FILE: src/cohere/v2/types/v2chat_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.assistant_message_response import AssistantMessageResponse
from ...types.chat_finish_reason import ChatFinishReason
from ...types.logprob_item import LogprobItem
from ...types.usage import Usage


class V2ChatResponse(UncheckedBaseModel):
    """Deserialized body of a (non-streamed) v2 chat response."""

    id: str = pydantic.Field()
    """
    Unique identifier for the generated reply. Useful for submitting feedback.
    """

    # See ChatFinishReason for the set of possible values.
    finish_reason: ChatFinishReason
    message: AssistantMessageResponse
    # Optional fields: not every response carries usage or logprob data.
    usage: typing.Optional[Usage] = None
    logprobs: typing.Optional[typing.List[LogprobItem]] = None

    # extra="allow": unknown fields in the payload are kept rather than rejected.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/v2/types/v2chat_stream_request_documents_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

from ...types.document import Document

# A request document may be supplied either as a plain string or as a
# structured Document object.
V2ChatStreamRequestDocumentsItem = typing.Union[str, Document]
================================================
FILE: src/cohere/v2/types/v2chat_stream_request_safety_mode.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Known values are "CONTEXTUAL", "STRICT" and "OFF"; the union with Any keeps
# the alias permissive toward values outside that literal set.
V2ChatStreamRequestSafetyMode = typing.Union[typing.Literal["CONTEXTUAL", "STRICT", "OFF"], typing.Any]
================================================
FILE: src/cohere/v2/types/v2chat_stream_request_tool_choice.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Known values are "REQUIRED" and "NONE"; the union with Any keeps the alias
# permissive toward values outside that literal set.
V2ChatStreamRequestToolChoice = typing.Union[typing.Literal["REQUIRED", "NONE"], typing.Any]
================================================
FILE: src/cohere/v2/types/v2chat_stream_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
from __future__ import annotations

import typing

import pydantic
import typing_extensions

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel, UnionMetadata
from ...types.chat_content_delta_event_delta import ChatContentDeltaEventDelta
from ...types.chat_content_start_event_delta import ChatContentStartEventDelta
from ...types.chat_message_end_event_delta import ChatMessageEndEventDelta
from ...types.chat_message_start_event_delta import ChatMessageStartEventDelta
from ...types.chat_tool_call_delta_event_delta import ChatToolCallDeltaEventDelta
from ...types.chat_tool_call_start_event_delta import ChatToolCallStartEventDelta
from ...types.chat_tool_plan_delta_event_delta import ChatToolPlanDeltaEventDelta
from ...types.citation_start_event_delta import CitationStartEventDelta
from ...types.logprob_item import LogprobItem

# One model per v2 chat stream event. Each class carries a fixed `type`
# literal that serves as the discriminant for the V2ChatStreamResponse union
# at the bottom of this module. Every model is configured with extra="allow"
# so unknown fields in the payload are kept rather than rejected.


class MessageStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["message-start"] = "message-start"
    id: typing.Optional[str] = None
    delta: typing.Optional[ChatMessageStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ContentStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["content-start"] = "content-start"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatContentStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ContentDeltaV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["content-delta"] = "content-delta"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatContentDeltaEventDelta] = None
    logprobs: typing.Optional[LogprobItem] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ContentEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["content-end"] = "content-end"
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolPlanDeltaV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-plan-delta"] = "tool-plan-delta"
    delta: typing.Optional[ChatToolPlanDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolCallStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-call-start"] = "tool-call-start"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatToolCallStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolCallDeltaV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-call-delta"] = "tool-call-delta"
    index: typing.Optional[int] = None
    delta: typing.Optional[ChatToolCallDeltaEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class ToolCallEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["tool-call-end"] = "tool-call-end"
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class CitationStartV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["citation-start"] = "citation-start"
    index: typing.Optional[int] = None
    delta: typing.Optional[CitationStartEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class CitationEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["citation-end"] = "citation-end"
    index: typing.Optional[int] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class MessageEndV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["message-end"] = "message-end"
    id: typing.Optional[str] = None
    delta: typing.Optional[ChatMessageEndEventDelta] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


class DebugV2ChatStreamResponse(UncheckedBaseModel):
    """
    StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request).
    """

    type: typing.Literal["debug"] = "debug"
    prompt: typing.Optional[str] = None

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow


# Discriminated union over all stream event models; UnionMetadata instructs
# the deserializer to dispatch on the `type` field.
V2ChatStreamResponse = typing_extensions.Annotated[
    typing.Union[
        MessageStartV2ChatStreamResponse,
        ContentStartV2ChatStreamResponse,
        ContentDeltaV2ChatStreamResponse,
        ContentEndV2ChatStreamResponse,
        ToolPlanDeltaV2ChatStreamResponse,
        ToolCallStartV2ChatStreamResponse,
        ToolCallDeltaV2ChatStreamResponse,
        ToolCallEndV2ChatStreamResponse,
        CitationStartV2ChatStreamResponse,
        CitationEndV2ChatStreamResponse,
        MessageEndV2ChatStreamResponse,
        DebugV2ChatStreamResponse,
    ],
    UnionMetadata(discriminant="type"),
]
================================================
FILE: src/cohere/v2/types/v2embed_request_truncate.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

# Known values are "NONE", "START" and "END"; the union with Any keeps the
# alias permissive toward values outside that literal set.
V2EmbedRequestTruncate = typing.Union[typing.Literal["NONE", "START", "END"], typing.Any]
================================================
FILE: src/cohere/v2/types/v2rerank_response.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel
from ...types.api_meta import ApiMeta
from .v2rerank_response_results_item import V2RerankResponseResultsItem


class V2RerankResponse(UncheckedBaseModel):
    """Deserialized body of a v2 rerank response."""

    id: typing.Optional[str] = None
    results: typing.List[V2RerankResponseResultsItem] = pydantic.Field()
    """
    An ordered list of ranked documents
    """

    # API metadata; optional because not every response includes it.
    meta: typing.Optional[ApiMeta] = None

    # extra="allow": unknown fields in the payload are kept rather than rejected.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/v2/types/v2rerank_response_results_item.py
================================================
# This file was auto-generated by Fern from our API Definition.
import typing

import pydantic

from ...core.pydantic_utilities import IS_PYDANTIC_V2
from ...core.unchecked_base_model import UncheckedBaseModel


class V2RerankResponseResultsItem(UncheckedBaseModel):
    """A single ranked document within a v2 rerank response."""

    index: int = pydantic.Field()
    """
    Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance)
    """

    relevance_score: float = pydantic.Field()
    """
    Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45
    """

    # extra="allow": unknown fields in the payload are kept rather than rejected.
    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            smart_union = True
            extra = pydantic.Extra.allow
================================================
FILE: src/cohere/version.py
================================================
from importlib import metadata

# Version of the installed `cohere` distribution. Fall back to a placeholder
# when distribution metadata is unavailable (e.g. running from a source
# checkout that was never pip-installed) instead of failing at import time
# with PackageNotFoundError.
try:
    __version__ = metadata.version("cohere")
except metadata.PackageNotFoundError:
    __version__ = "0.0.0"
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/embed_job.jsonl
================================================
{"text": "The quick brown fox jumps over the lazy dog"}
================================================
FILE: tests/test_async_client.py
================================================
import os
import unittest
import cohere
from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \
ToolParameterDefinitionsValue, ToolResult, UserMessage, ChatbotMessage
# Resolve the sample JSONL fixture relative to this test module so the tests
# work regardless of the current working directory.
package_dir = os.path.dirname(os.path.abspath(__file__))
embed_job = os.path.join(package_dir, 'embed_job.jsonl')
class TestClient(unittest.IsolatedAsyncioTestCase):
co: cohere.AsyncClient
def setUp(self) -> None:
    # One client per test; the API key comes from the environment.
    # NOTE(review): AsyncClient's timeout is in seconds, so 10000 is ~2.8
    # hours — looks like it may have been intended as milliseconds; confirm.
    self.co = cohere.AsyncClient(timeout=10000)
async def test_token_falls_back_on_env_variable(self) -> None:
    """Constructing a client with api_key=None (positional or keyword) must not raise."""
    cohere.AsyncClient(None)
    cohere.AsyncClient(api_key=None)
async def test_context_manager(self) -> None:
    """AsyncClient must be usable as an async context manager."""
    async with cohere.AsyncClient(api_key="xxx") as managed_client:
        self.assertIsNotNone(managed_client)
async def test_chat(self) -> None:
    # Live round-trip against the v1 chat endpoint with a short two-turn
    # history; this is a smoke test — it only checks the call completes.
    chat = await self.co.chat(
        model="command-a-03-2025",
        chat_history=[
            UserMessage(
                message="Who discovered gravity?"),
            ChatbotMessage(message="The man who is widely credited with discovering "
                                   "gravity is Sir Isaac Newton")
        ],
        message="What year was he born?",
    )
    print(chat)
async def test_chat_stream(self) -> None:
stream = self.co.chat_stream(
model="command-a-03-2025",
chat_history=[
UserMessage(
message="Who discovered gravity?"),
ChatbotMessage(message="The man who is widely credited with discovering "
"gravity is Sir Isaac Newton")
],
message="What year was he born?",
)
events = set()
async for chat_event in stream:
events.add(chat_event.event_type)
if chat_event.event_type == "text-generation":
print(chat_event.text)
self.assertTrue("text-generation" in events)
self.assertTrue("stream-start" in events)
self.assertTrue("stream-end" in events)
async def test_stream_equals_true(self) -> None:
with self.assertRaises(ValueError):
await self.co.chat(
stream=True, # type: ignore
message="What year was he born?",
)
async def test_deprecated_fn(self) -> None:
with self.assertRaises(ValueError):
await self.co.check_api_key("dummy", dummy="dummy") # type: ignore
async def test_moved_fn(self) -> None:
with self.assertRaises(ValueError):
await self.co.list_connectors("dummy", dummy="dummy") # type: ignore
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
async def test_embed(self) -> None:
response = await self.co.embed(
texts=['hello', 'goodbye'],
model='embed-english-v3.0',
input_type="classification"
)
print(response)
async def test_embed_batch_types(self) -> None:
# batch more than 96 texts
response = await self.co.embed(
texts=['hello']*100,
model='embed-english-v3.0',
input_type="classification",
embedding_types=["float", "int8", "uint8", "binary", "ubinary"]
)
if response.response_type == "embeddings_by_type":
self.assertEqual(len(response.texts or []), 100)
self.assertEqual(len(response.embeddings.float_ or []), 100)
self.assertEqual(len(response.embeddings.int8 or []), 100)
self.assertEqual(len(response.embeddings.uint8 or []), 100)
self.assertEqual(len(response.embeddings.binary or []), 100)
self.assertEqual(len(response.embeddings.ubinary or []), 100)
else:
self.fail("Expected embeddings_by_type response type")
print(response)
async def test_embed_batch_v1(self) -> None:
# batch more than 96 texts
response = await self.co.embed(
texts=['hello']*100,
model='embed-english-v3.0',
input_type="classification",
)
if response.response_type == "embeddings_floats":
self.assertEqual(len(response.embeddings), 100)
else:
self.fail("Expected embeddings_floats response type")
print(response)
@unittest.skip("temp")
async def test_embed_job_crud(self) -> None:
dataset = await self.co.datasets.create(
name="test",
type="embed-input",
data=open(embed_job, 'rb'),
)
result = await self.co.wait(dataset)
self.assertEqual(result.dataset.validation_status, "validated")
# start an embed job
job = await self.co.embed_jobs.create(
dataset_id=dataset.id or "",
input_type="search_document",
model='embed-english-v3.0')
print(job)
# list embed jobs
my_embed_jobs = await self.co.embed_jobs.list()
print(my_embed_jobs)
emb_result = await self.co.wait(job)
self.assertEqual(emb_result.status, "complete")
await self.co.embed_jobs.cancel(job.job_id)
await self.co.datasets.delete(dataset.id or "")
async def test_rerank(self) -> None:
docs = [
'Carson City is the capital city of the American state of Nevada.',
'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
response = await self.co.rerank(
model='rerank-v3.5',
query='What is the capital of the United States?',
documents=docs,
top_n=3,
)
print(response)
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
async def test_datasets_crud(self) -> None:
my_dataset = await self.co.datasets.create(
name="test",
type="embed-input",
data=open(embed_job, 'rb'),
)
print(my_dataset)
my_datasets = await self.co.datasets.list()
print(my_datasets)
dataset = await self.co.datasets.get(my_dataset.id or "")
print(dataset)
await self.co.datasets.delete(my_dataset.id or "")
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
async def test_save_load(self) -> None:
my_dataset = await self.co.datasets.create(
name="test",
type="embed-input",
data=open(embed_job, 'rb'),
)
result = await self.co.wait(my_dataset)
self.co.utils.save_dataset(result.dataset, "dataset.jsonl")
# assert files equal
self.assertTrue(os.path.exists("dataset.jsonl"))
self.assertEqual(open(embed_job, 'rb').read(),
open("dataset.jsonl", 'rb').read())
print(result)
await self.co.datasets.delete(my_dataset.id or "")
async def test_tokenize(self) -> None:
response = await self.co.tokenize(
text='tokenize me! :D',
model="command-a-03-2025",
offline=False,
)
print(response)
async def test_detokenize(self) -> None:
response = await self.co.detokenize(
tokens=[10104, 12221, 1315, 34, 1420, 69],
model="command-a-03-2025",
offline=False,
)
print(response)
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
async def test_tool_use(self) -> None:
tools = [
Tool(
name="sales_database",
description="Connects to a database about sales volumes",
parameter_definitions={
"day": ToolParameterDefinitionsValue(
description="Retrieves sales data from this day, formatted as YYYY-MM-DD.",
type="str",
required=True
)}
)
]
tool_parameters_response = await self.co.chat(
message="How good were the sales on September 29 2023?",
tools=tools,
model="command-nightly",
preamble="""
## Task Description
You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.
## Style Guide
Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.
"""
)
if tool_parameters_response.tool_calls is not None:
self.assertEqual(
tool_parameters_response.tool_calls[0].name, "sales_database")
self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {
"day": "2023-09-29"})
else:
raise ValueError("Expected tool calls to be present")
local_tools = {
"sales_database": lambda day: {
"number_of_sales": 120,
"total_revenue": 48500,
"average_sale_value": 404.17,
"date": "2023-09-29"
}
}
tool_results = []
for tool_call in tool_parameters_response.tool_calls:
output = local_tools[tool_call.name](**tool_call.parameters)
outputs = [output]
tool_results.append(ToolResult(
call=tool_call,
outputs=outputs
))
cited_response = await self.co.chat(
message="How good were the sales on September 29?",
tools=tools,
tool_results=tool_results,
force_single_step=True,
model="command-a-03-2025",
)
self.assertEqual(cited_response.documents, [
{
"average_sale_value": "404.17",
"date": "2023-09-29",
"id": "sales_database:0:0",
"number_of_sales": "120",
"total_revenue": "48500",
}
])
async def test_local_tokenize(self) -> None:
response = await self.co.tokenize(
model="command-a-03-2025",
text="tokenize me! :D"
)
print(response)
async def test_local_detokenize(self) -> None:
response = await self.co.detokenize(
model="command-a-03-2025",
tokens=[10104, 12221, 1315, 34, 1420, 69]
)
print(response)
async def test_tokenize_async_context_with_sync_client(self) -> None:
# Test that the sync client can be used in an async context.
co = cohere.Client(timeout=10000)
print(co.tokenize(model="command-a-03-2025", text="tokenize me! :D"))
print(co.detokenize(model="command-a-03-2025", tokens=[
10104, 12221, 1315, 34, 1420, 69]))
================================================
FILE: tests/test_aws_client_unit.py
================================================
"""
Unit tests (mocked, no AWS credentials needed) for AWS client fixes.
Covers:
- Fix 1: SigV4 signing uses the correct host header after URL rewrite
- Fix 2: cohere_aws.Client conditionally initializes based on mode
- Fix 3: embed() accepts and passes output_dimension and embedding_types
"""
import inspect
import json
import os
import unittest
from unittest.mock import MagicMock, patch
import httpx
from cohere.manually_maintained.cohere_aws.mode import Mode
class TestSigV4HostHeader(unittest.TestCase):
    """Fix 1: after the URL is rewritten to the Bedrock/SageMaker endpoint,
    the headers handed to AWSRequest for SigV4 signing must carry the new
    host -- never the stale api.cohere.com."""

    def test_sigv4_signs_with_correct_host(self) -> None:
        seen_kwargs: dict = {}

        def record_request(**kwargs):  # type: ignore
            # Capture the kwargs used to build the AWSRequest so we can
            # inspect the signed headers afterwards.
            seen_kwargs.update(kwargs)
            fake_request = MagicMock()
            fake_request.prepare.return_value = MagicMock(
                headers={"host": "bedrock-runtime.us-east-1.amazonaws.com"}
            )
            return fake_request

        fake_botocore = MagicMock()
        fake_botocore.awsrequest.AWSRequest = MagicMock(side_effect=record_request)
        fake_botocore.auth.SigV4Auth.return_value = MagicMock()

        fake_session = MagicMock()
        fake_session.region_name = "us-east-1"
        fake_session.get_credentials.return_value = MagicMock()
        fake_boto3 = MagicMock()
        fake_boto3.Session.return_value = fake_session

        with patch("cohere.aws_client.lazy_botocore", return_value=fake_botocore), \
                patch("cohere.aws_client.lazy_boto3", return_value=fake_boto3):
            from cohere.aws_client import map_request_to_bedrock

            sign_hook = map_request_to_bedrock(service="bedrock", aws_region="us-east-1")
            outgoing = httpx.Request(
                method="POST",
                url="https://api.cohere.com/v1/chat",
                headers={"connection": "keep-alive"},
                json={"model": "cohere.command-r-plus-v1:0", "message": "hello"},
            )
            # Sanity check: before the hook runs, the URL still targets the
            # public Cohere API.
            self.assertEqual(outgoing.url.host, "api.cohere.com")

            sign_hook(outgoing)

            # The hook must have rewritten the URL to the Bedrock endpoint...
            self.assertIn("bedrock-runtime.us-east-1.amazonaws.com", str(outgoing.url))
            # ...and signed against the rewritten host, not the original one.
            self.assertEqual(
                seen_kwargs["headers"]["host"],
                "bedrock-runtime.us-east-1.amazonaws.com",
            )
class TestModeConditionalInit(unittest.TestCase):
    """Fix 2: cohere_aws.Client must only construct the boto3 clients that
    match its mode, and must default to SAGEMAKER for backwards compat."""

    def test_sagemaker_mode_creates_sagemaker_clients(self) -> None:
        fake_boto3 = MagicMock()
        fake_sagemaker = MagicMock()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=fake_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=fake_sagemaker), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-east-1")

        self.assertEqual(client.mode, Mode.SAGEMAKER)
        requested = [call.args[0] for call in fake_boto3.client.call_args_list]
        # Only the SageMaker services may have been requested from boto3.
        for wanted in ("sagemaker-runtime", "sagemaker"):
            self.assertIn(wanted, requested)
        for unwanted in ("bedrock-runtime", "bedrock"):
            self.assertNotIn(unwanted, requested)
        fake_sagemaker.Session.assert_called_once()

    def test_bedrock_mode_creates_bedrock_clients(self) -> None:
        fake_boto3 = MagicMock()
        fake_sagemaker = MagicMock()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=fake_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=fake_sagemaker), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-west-2"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-west-2", mode=Mode.BEDROCK)

        self.assertEqual(client.mode, Mode.BEDROCK)
        requested = [call.args[0] for call in fake_boto3.client.call_args_list]
        # Only the Bedrock services may have been requested from boto3.
        for wanted in ("bedrock-runtime", "bedrock"):
            self.assertIn(wanted, requested)
        for unwanted in ("sagemaker-runtime", "sagemaker"):
            self.assertNotIn(unwanted, requested)
        # Bedrock mode must not touch the sagemaker SDK at all.
        fake_sagemaker.Session.assert_not_called()

    def test_default_mode_is_sagemaker(self) -> None:
        from cohere.manually_maintained.cohere_aws.client import Client

        default_mode = inspect.signature(Client.__init__).parameters["mode"].default
        self.assertEqual(default_mode, Mode.SAGEMAKER)
class TestEmbedV4Params(unittest.TestCase):
    """Fix 3: embed() must accept output_dimension / embedding_types, forward
    them into the request body, and drop them entirely when left as None."""

    @staticmethod
    def _make_bedrock_client():  # type: ignore
        """Build (fake_boto3, fake_botocore, captured_body): a mocked boto3
        whose bedrock-runtime client records the JSON body of invoke_model."""
        fake_boto3 = MagicMock()
        fake_botocore = MagicMock()
        captured_body: dict = {}

        def fake_invoke_model(**kwargs):  # type: ignore
            captured_body.update(json.loads(kwargs["body"]))
            body_mock = MagicMock()
            body_mock.read.return_value = json.dumps({"embeddings": [[0.1, 0.2]]}).encode()
            return {"body": body_mock}

        runtime_client = MagicMock()
        runtime_client.invoke_model.side_effect = fake_invoke_model

        def route_client(service_name, **kwargs):  # type: ignore
            if service_name == "bedrock-runtime":
                return runtime_client
            return MagicMock()

        fake_boto3.client.side_effect = route_client
        return fake_boto3, fake_botocore, captured_body

    def test_embed_accepts_new_params(self) -> None:
        from cohere.manually_maintained.cohere_aws.client import Client

        params = inspect.signature(Client.embed).parameters
        # Both v4 parameters must exist and default to None.
        for name in ("output_dimension", "embedding_types"):
            self.assertIn(name, params)
            self.assertIsNone(params[name].default)

    def test_embed_passes_params_to_bedrock(self) -> None:
        fake_boto3, fake_botocore, captured_body = self._make_bedrock_client()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=fake_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=fake_botocore), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            Client(aws_region="us-east-1", mode=Mode.BEDROCK).embed(
                texts=["hello world"],
                input_type="search_document",
                model_id="cohere.embed-english-v3",
                output_dimension=256,
                embedding_types=["float", "int8"],
            )

        self.assertEqual(captured_body["output_dimension"], 256)
        self.assertEqual(captured_body["embedding_types"], ["float", "int8"])

    def test_embed_omits_none_params(self) -> None:
        fake_boto3, fake_botocore, captured_body = self._make_bedrock_client()
        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=fake_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=fake_botocore), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            Client(aws_region="us-east-1", mode=Mode.BEDROCK).embed(
                texts=["hello world"],
                input_type="search_document",
                model_id="cohere.embed-english-v3",
            )

        # Params that were left as None must be stripped from the body.
        self.assertNotIn("output_dimension", captured_body)
        self.assertNotIn("embedding_types", captured_body)

    def test_embed_with_embedding_types_returns_dict(self) -> None:
        """With embedding_types set, the API returns embeddings keyed by type;
        the client must hand that dict back untouched instead of wrapping it
        in Embeddings."""
        expected_embeddings = {"float": [[0.1, 0.2]], "int8": [[1, 2]]}

        def fake_invoke_model(**kwargs):  # type: ignore
            body_mock = MagicMock()
            body_mock.read.return_value = json.dumps({
                "embeddings": expected_embeddings,
                "response_type": "embeddings_by_type",
            }).encode()
            return {"body": body_mock}

        runtime_client = MagicMock()
        runtime_client.invoke_model.side_effect = fake_invoke_model

        def route_client(service_name, **kwargs):  # type: ignore
            if service_name == "bedrock-runtime":
                return runtime_client
            return MagicMock()

        fake_boto3 = MagicMock()
        fake_boto3.client.side_effect = route_client

        with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=fake_boto3), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=MagicMock()), \
                patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
                patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
            from cohere.manually_maintained.cohere_aws.client import Client

            result = Client(aws_region="us-east-1", mode=Mode.BEDROCK).embed(
                texts=["hello world"],
                input_type="search_document",
                model_id="cohere.embed-english-v3",
                embedding_types=["float", "int8"],
            )

        self.assertIsInstance(result, dict)
        self.assertEqual(result, expected_embeddings)
================================================
FILE: tests/test_bedrock_client.py
================================================
import os
import unittest
import typing
import cohere
# Custom (non-standard) test environment variables holding AWS credentials
# and endpoint configuration; read once at import time.
aws_access_key = os.getenv("AWS_ACCESS_KEY")
aws_secret_key = os.getenv("AWS_SECRET_KEY")
aws_session_token = os.getenv("AWS_SESSION_TOKEN")
aws_region = os.getenv("AWS_REGION")
endpoint_type = os.getenv("ENDPOINT_TYPE")
def _setup_boto3_env():
    """Copy the custom test credential env vars into the standard boto3
    credential env vars, skipping any that are unset or empty."""
    bridged = (
        ("AWS_ACCESS_KEY_ID", aws_access_key),
        ("AWS_SECRET_ACCESS_KEY", aws_secret_key),
        ("AWS_SESSION_TOKEN", aws_session_token),
    )
    for env_name, value in bridged:
        if value:
            os.environ[env_name] = value
@unittest.skipIf(os.getenv("TEST_AWS") is None, "tests skipped because TEST_AWS is not set")
class TestClient(unittest.TestCase):
    """Integration tests for cohere.BedrockClient (live AWS calls).

    Requires TEST_AWS plus AWS credentials in the environment.  The
    ``platform``/``models`` class attributes are presumably overridden by a
    SageMaker variant of this suite -- TODO confirm against sibling tests.
    """

    platform: str = "bedrock"
    models: typing.Dict[str, str] = {
        "chat_model": "cohere.command-r-plus-v1:0",
        "embed_model": "cohere.embed-multilingual-v3",
        "generate_model": "cohere.command-text-v14",
    }

    def setUp(self) -> None:
        # Credentials come from the module-level env vars read at import time.
        self.client = cohere.BedrockClient(
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )

    def test_rerank(self) -> None:
        # Rerank is only exercised on SageMaker; a "rerank_model" entry is
        # expected in ``models`` whenever platform == "sagemaker".
        if self.platform != "sagemaker":
            self.skipTest("Only sagemaker supports rerank")
        docs = [
            'Carson City is the capital city of the American state of Nevada.',
            'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
            'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
            'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
        response = self.client.rerank(
            model=self.models["rerank_model"],
            query='What is the capital of the United States?',
            documents=docs,
            top_n=3,
        )
        self.assertEqual(len(response.results), 3)

    def test_embed(self) -> None:
        """embed() succeeds against a live Bedrock endpoint."""
        response = self.client.embed(
            model=self.models["embed_model"],
            texts=["I love Cohere!"],
            input_type="search_document",
        )
        print(response)

    def test_generate(self) -> None:
        """generate() succeeds against a live Bedrock endpoint."""
        response = self.client.generate(
            model=self.models["generate_model"],
            prompt='Please explain to me how LLMs work',
        )
        print(response)

    def test_generate_stream(self) -> None:
        """generate_stream() yields text-generation events."""
        response = self.client.generate_stream(
            model=self.models["generate_model"],
            prompt='Please explain to me how LLMs work',
        )
        for event in response:
            print(event)
            if event.event_type == "text-generation":
                print(event.text, end='')

    def test_chat(self) -> None:
        """chat() returns text plus fully populated metadata."""
        response = self.client.chat(
            model=self.models["chat_model"],
            message='Please explain to me how LLMs work',
        )
        print(response)
        self.assertIsNotNone(response.text)
        self.assertIsNotNone(response.generation_id)
        self.assertIsNotNone(response.finish_reason)
        self.assertIsNotNone(response.meta)
        if response.meta is not None:
            self.assertIsNotNone(response.meta.tokens)
            if response.meta.tokens is not None:
                self.assertIsNotNone(response.meta.tokens.input_tokens)
                self.assertIsNotNone(response.meta.tokens.output_tokens)
            self.assertIsNotNone(response.meta.billed_units)
            if response.meta.billed_units is not None:
                self.assertIsNotNone(response.meta.billed_units.input_tokens)
                # Bug fix: this previously re-asserted input_tokens; it must
                # check output_tokens, mirroring the meta.tokens checks above.
                self.assertIsNotNone(response.meta.billed_units.output_tokens)

    def test_chat_stream(self) -> None:
        """chat_stream() yields exactly text-generation and stream-end events."""
        response_types = set()
        response = self.client.chat_stream(
            model=self.models["chat_model"],
            message='Please explain to me how LLMs work',
        )
        for event in response:
            response_types.add(event.event_type)
            if event.event_type == "text-generation":
                print(event.text, end='')
                self.assertIsNotNone(event.text)
            if event.event_type == "stream-end":
                self.assertIsNotNone(event.finish_reason)
                self.assertIsNotNone(event.response)
                self.assertIsNotNone(event.response.text)
        self.assertSetEqual(response_types, {"text-generation", "stream-end"})
@unittest.skipIf(os.getenv("TEST_AWS") is None, "tests skipped because TEST_AWS is not set")
class TestBedrockClientV2(unittest.TestCase):
    """Integration tests for BedrockClientV2 (httpx-based).

    Fix 1 validation: if these pass, SigV4 signing uses the correct host
    header, since the request would fail with a signature mismatch otherwise.
    """

    def setUp(self) -> None:
        # Credentials come from the module-level env vars read at import time.
        self.client = cohere.BedrockClientV2(
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )

    def test_embed(self) -> None:
        """embed() with embedding_types succeeds against live Bedrock."""
        response = self.client.embed(
            model="cohere.embed-multilingual-v3",
            texts=["I love Cohere!"],
            input_type="search_document",
            embedding_types=["float"],
        )
        self.assertIsNotNone(response)

    def test_embed_with_output_dimension(self) -> None:
        """embed() forwards output_dimension without a request error."""
        response = self.client.embed(
            model="cohere.embed-english-v3",
            texts=["I love Cohere!"],
            input_type="search_document",
            embedding_types=["float"],
            output_dimension=256,
        )
        self.assertIsNotNone(response)
@unittest.skipIf(os.getenv("TEST_AWS") is None, "tests skipped because TEST_AWS is not set")
class TestCohereAwsBedrockClient(unittest.TestCase):
    """Integration tests for cohere_aws.Client in Bedrock mode (boto3-based).

    Validates:
    - Fix 2: Client can be initialized with mode=BEDROCK without importing sagemaker
    - Fix 3: embed() accepts output_dimension and embedding_types
    """

    # Shared across tests; built once in setUpClass.
    client: typing.Any = None

    @classmethod
    def setUpClass(cls) -> None:
        # Bridge the custom test env vars to the standard boto3 ones before
        # the client reads credentials.
        _setup_boto3_env()
        from cohere.manually_maintained.cohere_aws.client import Client
        from cohere.manually_maintained.cohere_aws.mode import Mode
        cls.client = Client(aws_region=aws_region, mode=Mode.BEDROCK)

    def test_client_is_bedrock_mode(self) -> None:
        from cohere.manually_maintained.cohere_aws.mode import Mode
        self.assertEqual(self.client.mode, Mode.BEDROCK)

    def test_embed(self) -> None:
        """Plain embed() returns an object exposing .embeddings."""
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-multilingual-v3",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        self.assertGreater(len(response.embeddings), 0)

    def test_embed_with_embedding_types(self) -> None:
        """embed() with embedding_types returns the raw per-type dict."""
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-multilingual-v3",
            embedding_types=["float"],
        )
        self.assertIsNotNone(response)
        # When embedding_types is passed, the response is a raw dict
        self.assertIsInstance(response, dict)
        self.assertIn("float", response)

    def test_embed_with_output_dimension(self) -> None:
        """embed() accepts output_dimension alongside embedding_types."""
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-english-v3",
            output_dimension=256,
            embedding_types=["float"],
        )
        self.assertIsNotNone(response)
        # When embedding_types is passed, the response is a raw dict
        self.assertIsInstance(response, dict)
        self.assertIn("float", response)

    def test_embed_without_new_params(self) -> None:
        """Backwards compat: embed() still works without the new v4 params."""
        response = self.client.embed(
            texts=["I love Cohere!"],
            input_type="search_document",
            model_id="cohere.embed-multilingual-v3",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
================================================
FILE: tests/test_client.py
================================================
import json
import os
import unittest
import cohere
from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \
ToolParameterDefinitionsValue, ToolResult, ChatbotMessage, UserMessage, JsonObjectResponseFormat
# Module-level sync client shared by every test below, created once at import
# time; large timeout so long-running generations do not time the tests out.
co = cohere.Client(timeout=10000)

# Absolute path to the sample embed-input dataset shipped next to this module.
package_dir = os.path.dirname(os.path.abspath(__file__))
embed_job = os.path.join(package_dir, 'embed_job.jsonl')
class TestClient(unittest.TestCase):
def test_token_falls_back_on_env_variable(self) -> None:
cohere.Client(api_key=None)
cohere.Client(None)
def test_context_manager(self) -> None:
with cohere.Client(api_key="xxx") as client:
self.assertIsNotNone(client)
def test_chat(self) -> None:
chat = co.chat(
chat_history=[
UserMessage(
message="Who discovered gravity?"),
ChatbotMessage(message="The man who is widely credited with discovering "
"gravity is Sir Isaac Newton")
],
message="What year was he born?",
)
print(chat)
def test_chat_stream(self) -> None:
stream = co.chat_stream(
chat_history=[
UserMessage(
message="Who discovered gravity?"),
ChatbotMessage(message="The man who is widely credited with discovering "
"gravity is Sir Isaac Newton")
],
message="What year was he born?",
)
events = set()
for chat_event in stream:
events.add(chat_event.event_type)
if chat_event.event_type == "text-generation":
print(chat_event.text)
self.assertTrue("text-generation" in events)
self.assertTrue("stream-start" in events)
self.assertTrue("stream-end" in events)
def test_stream_equals_true(self) -> None:
with self.assertRaises(ValueError):
co.chat(
stream=True, # type: ignore
message="What year was he born?",
)
def test_deprecated_fn(self) -> None:
with self.assertRaises(ValueError):
co.check_api_key("dummy", dummy="dummy") # type: ignore
def test_moved_fn(self) -> None:
with self.assertRaises(ValueError):
co.list_connectors("dummy", dummy="dummy") # type: ignore
def test_embed(self) -> None:
response = co.embed(
texts=['hello', 'goodbye'],
model='embed-english-v3.0',
input_type="classification",
embedding_types=["float", "int8", "uint8", "binary", "ubinary"]
)
if response.response_type == "embeddings_by_type":
self.assertIsNotNone(response.embeddings.float) # type: ignore
self.assertIsNotNone(response.embeddings.float_)
if response.embeddings.float_ is not None:
self.assertEqual(type(response.embeddings.float_[0][0]), float)
if response.embeddings.int8 is not None:
self.assertEqual(type(response.embeddings.int8[0][0]), int)
if response.embeddings.uint8 is not None:
self.assertEqual(type(response.embeddings.uint8[0][0]), int)
if response.embeddings.binary is not None:
self.assertEqual(type(response.embeddings.binary[0][0]), int)
if response.embeddings.ubinary is not None:
self.assertEqual(type(response.embeddings.ubinary[0][0]), int)
print(response)
def test_image_embed(self) -> None:
response = co.embed(
images=['data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII='],
model='embed-multilingual-v3.0',
input_type="image",
embedding_types=["float"]
)
if response.response_type == "embeddings_by_type":
self.assertIsNotNone(response.embeddings.float) # type: ignore
self.assertIsNotNone(response.embeddings.float_)
if response.embeddings.float_ is not None:
self.assertEqual(type(response.embeddings.float_[0][0]), float)
if response.embeddings.int8 is not None:
self.assertEqual(type(response.embeddings.int8[0][0]), int)
if response.embeddings.uint8 is not None:
self.assertEqual(type(response.embeddings.uint8[0][0]), int)
if response.embeddings.binary is not None:
self.assertEqual(type(response.embeddings.binary[0][0]), int)
if response.embeddings.ubinary is not None:
self.assertEqual(type(response.embeddings.ubinary[0][0]), int)
print(response)
def test_embed_batch_types(self) -> None:
# batch more than 96 texts
response = co.embed(
texts=['hello'] * 100,
model='embed-english-v3.0',
input_type="classification",
embedding_types=["float", "int8", "uint8", "binary", "ubinary"]
)
if response.response_type == "embeddings_by_type":
self.assertEqual(len(response.texts or []), 100)
self.assertEqual(len(response.embeddings.float_ or []), 100)
self.assertEqual(len(response.embeddings.int8 or []), 100)
self.assertEqual(len(response.embeddings.uint8 or []), 100)
self.assertEqual(len(response.embeddings.binary or []), 100)
self.assertEqual(len(response.embeddings.ubinary or []), 100)
else:
self.fail("Expected embeddings_by_type response type")
print(response)
def test_embed_batch_v1(self) -> None:
# batch more than 96 texts
response = co.embed(
texts=['hello'] * 100,
model='embed-english-v3.0',
input_type="classification",
)
if response.response_type == "embeddings_floats":
self.assertEqual(len(response.embeddings), 100)
else:
self.fail("Expected embeddings_floats response type")
print(response)
@unittest.skip("temp")
def test_embed_job_crud(self) -> None:
dataset = co.datasets.create(
name="test",
type="embed-input",
data=open(embed_job, 'rb'),
)
result = co.wait(dataset)
self.assertEqual(result.dataset.validation_status, "validated")
# start an embed job
job = co.embed_jobs.create(
dataset_id=dataset.id or "",
input_type="search_document",
model='embed-english-v3.0')
print(job)
# list embed jobs
my_embed_jobs = co.embed_jobs.list()
print(my_embed_jobs)
emb_result = co.wait(job)
self.assertEqual(emb_result.status, "complete")
co.embed_jobs.cancel(job.job_id)
co.datasets.delete(dataset.id or "")
def test_rerank(self) -> None:
docs = [
'Carson City is the capital city of the American state of Nevada.',
'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
response = co.rerank(
model='rerank-v3.5',
query='What is the capital of the United States?',
documents=docs,
top_n=3,
)
print(response)
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
def test_datasets_crud(self) -> None:
my_dataset = co.datasets.create(
name="test",
type="embed-input",
data=open(embed_job, 'rb'),
)
print(my_dataset)
my_datasets = co.datasets.list()
print(my_datasets)
dataset = co.datasets.get(my_dataset.id or "")
print(dataset)
co.datasets.delete(my_dataset.id or "")
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
def test_save_load(self) -> None:
my_dataset = co.datasets.create(
name="test",
type="embed-input",
data=open(embed_job, 'rb'),
)
result = co.wait(my_dataset)
co.utils.save_dataset(result.dataset, "dataset.jsonl")
# assert files equal
self.assertTrue(os.path.exists("dataset.jsonl"))
self.assertEqual(open(embed_job, 'rb').read(),
open("dataset.jsonl", 'rb').read())
print(result)
co.datasets.delete(my_dataset.id or "")
def test_tokenize(self) -> None:
    """Tokenize a short string via the hosted endpoint (offline=False)."""
    result = co.tokenize(
        model='command-a-03-2025',
        text='tokenize me! :D',
        offline=False,  # explicitly use the remote tokenizer
    )
    print(result)
def test_detokenize(self) -> None:
    """Detokenize a fixed id sequence via the hosted endpoint (offline=False)."""
    token_ids = [10104, 12221, 1315, 34, 1420, 69]
    result = co.detokenize(
        model="command-a-03-2025",
        tokens=token_ids,
        offline=False,  # explicitly use the remote detokenizer
    )
    print(result)
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
def test_tool_use(self) -> None:
    """End-to-end tool-use flow: ask the model for tool parameters, run the
    tool locally, then send the results back for a cited answer."""
    # One tool with a single required string parameter ("day").
    tools = [
        Tool(
            name="sales_database",
            description="Connects to a database about sales volumes",
            parameter_definitions={
                "day": ToolParameterDefinitionsValue(
                    description="Retrieves sales data from this day, formatted as YYYY-MM-DD.",
                    type="str",
                    required=True
                )}
        )
    ]
    # First round: the model should emit a tool call rather than a final answer.
    tool_parameters_response = co.chat(
        message="How good were the sales on September 29 2023?",
        tools=tools,
        model="command-nightly",
        preamble="""
        ## Task Description
        You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.
        ## Style Guide
        Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.
        """
    )
    if tool_parameters_response.tool_calls is not None:
        # The model must have selected our tool and parsed the date from the prompt.
        self.assertEqual(
            tool_parameters_response.tool_calls[0].name, "sales_database")
        self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {
            "day": "2023-09-29"})
    else:
        raise ValueError("Expected tool calls to be present")
    # Fake local implementation of the tool, keyed by tool name.
    local_tools = {
        "sales_database": lambda day: {
            "number_of_sales": 120,
            "total_revenue": 48500,
            "average_sale_value": 404.17,
            "date": "2023-09-29"
        }
    }
    # Execute each requested tool call and collect its outputs.
    tool_results = []
    for tool_call in tool_parameters_response.tool_calls:
        output = local_tools[tool_call.name](**tool_call.parameters)
        outputs = [output]
        tool_results.append(ToolResult(
            call=tool_call,
            outputs=outputs
        ))
    # Second round: feed the tool results back for a grounded, cited answer.
    cited_response = co.chat(
        message="How good were the sales on September 29?",
        tools=tools,
        tool_results=tool_results,
        force_single_step=True,
        model="command-nightly",
    )
    # The API surfaces tool outputs as stringified documents with stable ids.
    self.assertEqual(cited_response.documents, [
        {
            "average_sale_value": "404.17",
            "date": "2023-09-29",
            "id": "sales_database:0:0",
            "number_of_sales": "120",
            "total_revenue": "48500",
        }
    ])
def test_local_tokenize(self) -> None:
    """Tokenize using the default (locally cached) tokenizer path."""
    result = co.tokenize(text="tokenize me! :D", model="command-a-03-2025")
    print(result)
def test_local_detokenize(self) -> None:
    """Detokenize using the default (locally cached) tokenizer path."""
    token_ids = [10104, 12221, 1315, 34, 1420, 69]
    result = co.detokenize(tokens=token_ids, model="command-a-03-2025")
    print(result)
================================================
FILE: tests/test_client_init.py
================================================
import os
import typing
import unittest
import cohere
from cohere import ToolMessage, UserMessage, AssistantMessage
import importlib.util
HAS_BOTO3 = importlib.util.find_spec("boto3") is not None
class TestClientInit(unittest.TestCase):
    """Smoke-tests: every client variant must be constructible."""

    @unittest.skipUnless(HAS_BOTO3, "boto3 not installed")
    def test_aws_inits(self) -> None:
        # AWS-backed clients need boto3 available at construction time.
        cohere.BedrockClient()
        cohere.BedrockClientV2()
        cohere.SagemakerClient()
        cohere.SagemakerClientV2()

    def test_inits(self) -> None:
        # Plain API clients accept any non-empty key at construction.
        for make_client in (cohere.Client, cohere.ClientV2):
            make_client(api_key="n/a")
================================================
FILE: tests/test_client_v2.py
================================================
import os
import typing
import unittest
import cohere
from cohere import ToolMessage, UserMessage, AssistantMessage
# Shared V2 client for every test in this module; generous timeout for CI.
co = cohere.ClientV2(timeout=10000)
# Path to the JSONL embed fixture shipped next to this test module.
package_dir = os.path.dirname(os.path.abspath(__file__))
embed_job = os.path.join(package_dir, "embed_job.jsonl")
class TestClientV2(unittest.TestCase):
    """Integration tests for the V2 client: chat, streaming, legacy shims,
    and (currently skipped) documents/tools flows."""

    def test_chat(self) -> None:
        # Minimal single-turn chat round-trip.
        response = co.chat(
            model="command-a-03-2025", messages=[cohere.UserChatMessageV2(content="hello world!")])
        print(response.message)

    def test_chat_stream(self) -> None:
        stream = co.chat_stream(
            model="command-a-03-2025", messages=[cohere.UserChatMessageV2(content="hello world!")])
        events = set()
        for chat_event in stream:
            if chat_event is not None:
                events.add(chat_event.type)
                if chat_event.type == "content-delta":
                    print(chat_event.delta)
        # The stream must emit the complete V2 event lifecycle.
        self.assertTrue("message-start" in events)
        self.assertTrue("content-start" in events)
        self.assertTrue("content-delta" in events)
        self.assertTrue("content-end" in events)
        self.assertTrue("message-end" in events)

    def test_legacy_methods_available(self) -> None:
        # V1-era endpoints stay exposed (and callable) on the V2 client.
        self.assertTrue(hasattr(co, "generate"))
        self.assertTrue(callable(getattr(co, "generate")))
        self.assertTrue(hasattr(co, "generate_stream"))
        self.assertTrue(callable(getattr(co, "generate_stream")))

    @unittest.skip("Skip v2 test for now")
    def test_chat_documents(self) -> None:
        from cohere import Document
        documents = [
            Document(data={"title": "widget sales 2019", "text": "1 million"}),
            Document(data={"title": "widget sales 2020", "text": "2 million"}),
            Document(data={"title": "widget sales 2021", "text": "4 million"}),
        ]
        response = co.chat(
            messages=[cohere.UserChatMessageV2(
                content=[cohere.TextContent(text="how many widges were sold in 2020?")],
            )],
            model="command-a-03-2025",
            documents=documents,
        )
        print(response.message)

    @unittest.skip("Skip v2 test for now")
    def test_chat_tools(self) -> None:
        from typing import Sequence
        # One function tool with a single required string parameter.
        get_weather_tool = cohere.ToolV2Function(
            name="get_weather",
            description="gets the weather of a given location",
            parameters={
                "type": "object",
                "properties": {
                    "location": {
                        "type": "str",
                        "description": "the location to get weather, example: San Fransisco, CA",
                    }
                },
                "required": ["location"],
            },
        )
        tools = [cohere.ToolV2(type="function", function=get_weather_tool)]
        messages: cohere.ChatMessages = [
            cohere.UserChatMessageV2(content="what is the weather in Toronto?")
        ]
        res = co.chat(model="command-a-03-2025", tools=tools, messages=messages)
        # call the get_weather tool
        tool_result = {"temperature": "30C"}
        tool_content: Sequence[cohere.TextToolContent] = [cohere.TextToolContent(text="The weather in Toronto is 30C")]
        # Use the first text content from the response if available, else fallback to str
        assistant_content = res.message.content[0].text if (hasattr(res.message, 'content') and isinstance(res.message.content, list) and len(res.message.content) > 0 and hasattr(res.message.content[0], 'text')) else str(res.message)
        messages.append(cohere.AssistantChatMessageV2(content=[cohere.TextAssistantMessageV2ContentItem(text=assistant_content)]))
        # Only append a tool message when the model actually requested a call.
        if res.message.tool_calls is not None and res.message.tool_calls[0].id is not None:
            messages.append(cohere.ToolChatMessageV2(
                tool_call_id=res.message.tool_calls[0].id, content=list(tool_content)))
        res = co.chat(tools=tools, messages=messages, model="command-a-03-2025")
        print(res.message)
================================================
FILE: tests/test_embed_streaming.py
================================================
"""Tests for memory-efficient embed_stream functionality.
All embed_stream code lives in manually maintained files (.fernignore protected):
- src/cohere/client.py — Client.embed_stream()
- src/cohere/manually_maintained/streaming_embed.py — StreamedEmbedding, extraction helpers
"""
import unittest
from cohere.manually_maintained.streaming_embed import (
StreamedEmbedding,
extract_embeddings_from_response,
)
from cohere.config import embed_stream_batch_size
class TestStreamedEmbedding(unittest.TestCase):
    """Behavioral checks for the StreamedEmbedding dataclass."""

    def test_creation(self):
        # All four fields round-trip through the constructor.
        item = StreamedEmbedding(index=0, embedding=[0.1, 0.2], embedding_type="float", text="hello")
        self.assertEqual(0, item.index)
        self.assertEqual([0.1, 0.2], item.embedding)
        self.assertEqual("float", item.embedding_type)
        self.assertEqual("hello", item.text)

    def test_text_optional(self):
        # `text` defaults to None when omitted.
        item = StreamedEmbedding(index=0, embedding=[0.1], embedding_type="float")
        self.assertIsNone(item.text)
class TestExtractEmbeddings(unittest.TestCase):
    """Exercise extract_embeddings_from_response across V1 and V2 payload shapes."""

    def test_v1_embeddings_floats(self):
        """V1 embeddings_floats response returns flat float embeddings."""
        payload = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        }
        emitted = list(extract_embeddings_from_response(payload, ["hello", "world"]))
        self.assertEqual(2, len(emitted))
        first, second = emitted
        self.assertEqual(0, first.index)
        self.assertEqual([0.1, 0.2, 0.3], first.embedding)
        self.assertEqual("float", first.embedding_type)
        self.assertEqual("hello", first.text)
        self.assertEqual(1, second.index)
        self.assertEqual("world", second.text)

    def test_v1_embeddings_by_type(self):
        """V1 embeddings_by_type response returns typed embeddings."""
        payload = {
            "response_type": "embeddings_by_type",
            "embeddings": {
                "float_": [[0.1, 0.2], [0.3, 0.4]],
                "int8": [[1, 2], [3, 4]],
            },
        }
        emitted = list(extract_embeddings_from_response(payload, ["a", "b"]))
        # Two texts times two embedding types.
        self.assertEqual(4, len(emitted))
        floats = [item for item in emitted if item.embedding_type == "float"]
        int8s = [item for item in emitted if item.embedding_type == "int8"]
        self.assertEqual(2, len(floats))
        self.assertEqual(2, len(int8s))

    def test_v2_response_format(self):
        """V2 response (no response_type) returns dict embeddings."""
        payload = {
            "embeddings": {
                "float_": [[0.1, 0.2], [0.3, 0.4]],
            },
        }
        emitted = list(extract_embeddings_from_response(payload, ["x", "y"]))
        self.assertEqual(2, len(emitted))
        self.assertEqual("float", emitted[0].embedding_type)
        self.assertEqual("x", emitted[0].text)

    def test_global_offset(self):
        """Global offset adjusts indices for batched processing."""
        payload = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1], [0.2]],
        }
        emitted = list(extract_embeddings_from_response(payload, ["c", "d"], global_offset=100))
        self.assertEqual(100, emitted[0].index)
        self.assertEqual(101, emitted[1].index)

    def test_empty_embeddings(self):
        """Empty response yields nothing."""
        payload = {"response_type": "embeddings_floats", "embeddings": []}
        self.assertEqual([], list(extract_embeddings_from_response(payload, [])))

    def test_texts_shorter_than_embeddings(self):
        """Text is None when batch_texts runs out."""
        payload = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1], [0.2], [0.3]],
        }
        emitted = list(extract_embeddings_from_response(payload, ["only_one"]))
        self.assertEqual("only_one", emitted[0].text)
        self.assertIsNone(emitted[1].text)
        self.assertIsNone(emitted[2].text)
class TestBatchSizeConstant(unittest.TestCase):
    """The default batch size must come from config, not a magic number."""

    def test_default_batch_size_matches_api_limit(self):
        # 96 texts per request is the embed API's documented batch limit.
        self.assertEqual(96, embed_stream_batch_size)
# Allow running this test module directly: `python tests/test_embed_streaming.py`.
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/test_embed_utils.py
================================================
import unittest
from cohere import EmbeddingsByTypeEmbedResponse, EmbedByTypeResponseEmbeddings, ApiMeta, ApiMetaBilledUnits, \
ApiMetaApiVersion, EmbeddingsFloatsEmbedResponse
from cohere.utils import merge_embed_responses
# ---- Fixtures for merge_embed_responses tests ------------------------------
# Two "by type" responses with all five embedding types populated; each has
# two texts, distinct ids, and distinct billed-unit counts so merge results
# (concatenation, id joining, billed-unit summing, warning union) are visible.
ebt_1 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="1",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[0, 1, 2], [3, 4, 5]],
        int8=[[0, 1, 2], [3, 4, 5]],
        uint8=[[0, 1, 2], [3, 4, 5]],
        binary=[[0, 1, 2], [3, 4, 5]],
        ubinary=[[0, 1, 2], [3, 4, 5]],
    ),
    texts=["hello", "goodbye"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=1,
            output_tokens=1,
            search_units=1,
            classifications=1
        ),
        warnings=["test_warning_1"]
    )
)
ebt_2 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="2",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[7, 8, 9], [10, 11, 12]],
        int8=[[7, 8, 9], [10, 11, 12]],
        uint8=[[7, 8, 9], [10, 11, 12]],
        binary=[[7, 8, 9], [10, 11, 12]],
        ubinary=[[7, 8, 9], [10, 11, 12]],
    ),
    texts=["bye", "seeya"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=2,
            output_tokens=2,
            search_units=2,
            classifications=2
        ),
        warnings=["test_warning_1", "test_warning_2"]
    )
)
# "Partial" variants populate only a subset of embedding types (no uint8 /
# ubinary) to check that merging preserves absent types as absent.
ebt_partial_1 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="1",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[0, 1, 2], [3, 4, 5]],
        int8=[[0, 1, 2], [3, 4, 5]],
        binary=[[5, 6, 7], [8, 9, 10]],
    ),
    texts=["hello", "goodbye"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=1,
            output_tokens=1,
            search_units=1,
            classifications=1
        ),
        warnings=["test_warning_1"]
    )
)
ebt_partial_2 = EmbeddingsByTypeEmbedResponse(
    response_type="embeddings_by_type",
    id="2",
    embeddings=EmbedByTypeResponseEmbeddings(
        float_=[[7, 8, 9], [10, 11, 12]],
        int8=[[7, 8, 9], [10, 11, 12]],
        binary=[[14, 15, 16], [17, 18, 19]],
    ),
    texts=["bye", "seeya"],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=2,
            output_tokens=2,
            search_units=2,
            classifications=2
        ),
        warnings=["test_warning_1", "test_warning_2"]
    )
)
# Legacy flat-float responses for the embeddings_floats merge path.
ebf_1 = EmbeddingsFloatsEmbedResponse(
    response_type="embeddings_floats",
    id="1",
    texts=["hello", "goodbye"],
    embeddings=[[0, 1, 2], [3, 4, 5]],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=1,
            output_tokens=1,
            search_units=1,
            classifications=1
        ),
        warnings=["test_warning_1"]
    )
)
ebf_2 = EmbeddingsFloatsEmbedResponse(
    response_type="embeddings_floats",
    id="2",
    texts=["bye", "seeya"],
    embeddings=[[7, 8, 9], [10, 11, 12]],
    meta=ApiMeta(
        api_version=ApiMetaApiVersion(version="1"),
        billed_units=ApiMetaBilledUnits(
            input_tokens=2,
            output_tokens=2,
            search_units=2,
            classifications=2
        ),
        warnings=["test_warning_1", "test_warning_2"]
    )
)
class TestClient(unittest.TestCase):
    """Tests for merge_embed_responses: ids join, embeddings and texts
    concatenate, billed units sum, warnings union (order-insensitive)."""

    def test_merge_embeddings_by_type(self) -> None:
        resp = merge_embed_responses([
            ebt_1,
            ebt_2
        ])
        if resp.meta is None:
            raise Exception("this is just for mpy")
        # Warnings are deduplicated; order is unspecified, so compare as sets.
        self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
        self.assertEqual(resp, EmbeddingsByTypeEmbedResponse(
            response_type="embeddings_by_type",
            id="1, 2",
            embeddings=EmbedByTypeResponseEmbeddings(
                float_=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                int8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                uint8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                binary=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                ubinary=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
            ),
            texts=["hello", "goodbye", "bye", "seeya"],
            meta=ApiMeta(
                api_version=ApiMetaApiVersion(version="1"),
                billed_units=ApiMetaBilledUnits(
                    input_tokens=3,
                    output_tokens=3,
                    search_units=3,
                    classifications=3
                ),
                warnings=resp.meta.warnings  # order ignored
            )
        ))

    def test_merge_embeddings_floats(self) -> None:
        resp = merge_embed_responses([
            ebf_1,
            ebf_2
        ])
        if resp.meta is None:
            raise Exception("this is just for mpy")
        self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
        self.assertEqual(resp, EmbeddingsFloatsEmbedResponse(
            response_type="embeddings_floats",
            id="1, 2",
            texts=["hello", "goodbye", "bye", "seeya"],
            embeddings=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
            meta=ApiMeta(
                api_version=ApiMetaApiVersion(version="1"),
                billed_units=ApiMetaBilledUnits(
                    input_tokens=3,
                    output_tokens=3,
                    search_units=3,
                    classifications=3
                ),
                warnings=resp.meta.warnings  # order ignored
            )
        ))

    def test_merge_partial_embeddings_floats(self) -> None:
        # Types absent from the inputs stay absent in the merged result.
        resp = merge_embed_responses([
            ebt_partial_1,
            ebt_partial_2
        ])
        if resp.meta is None:
            raise Exception("this is just for mpy")
        self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
        self.assertEqual(resp, EmbeddingsByTypeEmbedResponse(
            response_type="embeddings_by_type",
            id="1, 2",
            embeddings=EmbedByTypeResponseEmbeddings(
                float_=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                int8=[[0, 1, 2], [3, 4, 5], [7, 8, 9], [10, 11, 12]],
                binary=[[5, 6, 7], [8, 9, 10], [14, 15, 16], [17, 18, 19]],
            ),
            texts=["hello", "goodbye", "bye", "seeya"],
            meta=ApiMeta(
                api_version=ApiMetaApiVersion(version="1"),
                billed_units=ApiMetaBilledUnits(
                    input_tokens=3,
                    output_tokens=3,
                    search_units=3,
                    classifications=3
                ),
                warnings=resp.meta.warnings  # order ignored
            )
        ))
================================================
FILE: tests/test_oci_client.py
================================================
"""Integration and unit tests for OCI Generative AI client.
All integration tests are validated against the live OCI Generative AI inference
layer (us-chicago-1). The OciClientV2 uses the V2 Cohere API format (COHEREV2)
and communicates with the OCI inference endpoint at:
https://inference.generativeai.{region}.oci.oraclecloud.com
Integration test coverage:
V1 API (OciClient — Command R family):
Test Model What it proves
------------------------------- -------------------------- ------------------------------------------
test_embed embed-english-v3.0 V1 embed returns 2x 1024-dim float vectors
test_chat command-r-08-2024 V1 chat returns text with COHERE apiFormat
test_chat_stream command-r-08-2024 V1 streaming with text-generation events
V2 API (OciClientV2 — Command A family):
Test Model What it proves
------------------------------- -------------------------- ------------------------------------------
test_embed_v2 embed-english-v3.0 V2 embed returns dict with float_ key
test_embed_with_model_prefix_v2 cohere.embed-english-v3.0 Model normalization works
test_chat_v2 command-a-03-2025 V2 chat returns message with COHEREV2 format
test_chat_stream_v2 command-a-03-2025 V2 SSE streaming with content-delta events
test_command_a_chat command-a-03-2025 Command A chat via V2
Cross-cutting:
Test Model What it proves
------------------------------- -------------------------- ------------------------------------------
test_config_file_auth embed-english-v3.0 API key auth from config file
test_custom_profile_auth embed-english-v3.0 Custom OCI profile auth
test_embed_english_v3 embed-english-v3.0 1024-dim embeddings
test_embed_multilingual_v3 embed-multilingual-v3.0 Multilingual model works
test_invalid_model invalid-model-name Error handling works
test_missing_compartment_id -- Raises TypeError
Requirements:
1. OCI SDK installed: pip install oci
2. OCI credentials configured in ~/.oci/config
3. TEST_OCI environment variable set to run
4. OCI_COMPARTMENT_ID environment variable with valid OCI compartment OCID
5. OCI_REGION environment variable (optional, defaults to us-chicago-1)
Run with:
TEST_OCI=1 OCI_COMPARTMENT_ID=ocid1.compartment.oc1... pytest tests/test_oci_client.py
"""
import os
import sys
import tempfile
import types
import unittest
from unittest.mock import MagicMock, mock_open, patch
import cohere
# Stub out optional heavyweight dependencies so this module imports even when
# they are not installed. Each stub is only registered if the real package is
# absent from sys.modules.
# NOTE(review): these stubs run *after* `import cohere` above — if cohere
# imports tokenizers/fastavro/httpx_sse eagerly at import time, the stubs
# arrive too late to help; confirm the intended ordering.
if "tokenizers" not in sys.modules:
    tokenizers_stub = types.ModuleType("tokenizers")
    tokenizers_stub.Tokenizer = object  # type: ignore[attr-defined]
    sys.modules["tokenizers"] = tokenizers_stub
if "fastavro" not in sys.modules:
    fastavro_stub = types.ModuleType("fastavro")
    fastavro_stub.parse_schema = lambda schema: schema  # type: ignore[attr-defined]
    fastavro_stub.reader = lambda *args, **kwargs: iter(())  # type: ignore[attr-defined]
    fastavro_stub.writer = lambda *args, **kwargs: None  # type: ignore[attr-defined]
    sys.modules["fastavro"] = fastavro_stub
if "httpx_sse" not in sys.modules:
    httpx_sse_stub = types.ModuleType("httpx_sse")
    httpx_sse_stub.connect_sse = lambda *args, **kwargs: None  # type: ignore[attr-defined]
    sys.modules["httpx_sse"] = httpx_sse_stub
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
class TestOciClient(unittest.TestCase):
    """Test OciClient (V1 API) with OCI Generative AI."""

    def setUp(self):
        # Build the V1 client from environment configuration; skip when no
        # compartment OCID is available.
        compartment_id = os.getenv("OCI_COMPARTMENT_ID")
        if not compartment_id:
            self.skipTest("OCI_COMPARTMENT_ID not set")
        region = os.getenv("OCI_REGION", "us-chicago-1")
        profile = os.getenv("OCI_PROFILE", "DEFAULT")
        self.client = cohere.OciClient(
            oci_region=region,
            oci_compartment_id=compartment_id,
            oci_profile=profile,
        )

    def test_embed(self):
        """Test embedding with V1 client."""
        response = self.client.embed(
            model="embed-english-v3.0",
            texts=["Hello world", "Cohere on OCI"],
            input_type="search_document",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        # Two input texts -> two flat 1024-dim vectors (V1 float format).
        self.assertEqual(len(response.embeddings), 2)
        self.assertEqual(len(response.embeddings[0]), 1024)
        self.assertEqual(response.response_type, "embeddings_floats")

    def test_chat(self):
        """Test V1 chat with Command R."""
        response = self.client.chat(
            model="command-r-08-2024",
            message="What is 2+2? Answer with just the number.",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.text)
        self.assertIn("4", response.text)

    def test_chat_stream(self):
        """Test V1 streaming chat terminates and produces correct events."""
        events = []
        for event in self.client.chat_stream(
            model="command-r-08-2024",
            message="Count from 1 to 3.",
        ):
            events.append(event)
        self.assertTrue(len(events) > 0)
        # At least one event must carry generated text.
        text_events = [e for e in events if hasattr(e, "text") and e.text]
        self.assertTrue(len(text_events) > 0)
        # Verify stream terminates with correct event lifecycle
        event_types = [getattr(e, "event_type", None) for e in events]
        self.assertEqual(event_types[0], "stream-start")
        self.assertEqual(event_types[-1], "stream-end")
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
class TestOciClientV2(unittest.TestCase):
    """Test OciClientV2 (v2 API) with OCI Generative AI."""

    def setUp(self):
        """Set up OCI v2 client for each test."""
        compartment_id = os.getenv("OCI_COMPARTMENT_ID")
        if not compartment_id:
            self.skipTest("OCI_COMPARTMENT_ID not set")
        region = os.getenv("OCI_REGION", "us-chicago-1")
        profile = os.getenv("OCI_PROFILE", "DEFAULT")
        self.client = cohere.OciClientV2(
            oci_region=region,
            oci_compartment_id=compartment_id,
            oci_profile=profile,
        )

    def test_embed_v2(self):
        """Test embedding with v2 client."""
        response = self.client.embed(
            model="embed-english-v3.0",
            texts=["Hello from v2", "Second text"],
            input_type="search_document",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        # V2 returns embeddings as a dict with "float" key
        self.assertIsNotNone(response.embeddings.float_)
        self.assertEqual(len(response.embeddings.float_), 2)
        # Verify embedding dimensions (1024 for embed-english-v3.0)
        self.assertEqual(len(response.embeddings.float_[0]), 1024)
        self.assertEqual(response.response_type, "embeddings_by_type")

    def test_embed_with_model_prefix_v2(self):
        """Test embedding with 'cohere.' model prefix on v2 client."""
        response = self.client.embed(
            model="cohere.embed-english-v3.0",
            texts=["Test with prefix"],
            input_type="search_document",
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.embeddings)
        self.assertIsNotNone(response.embeddings.float_)
        self.assertEqual(len(response.embeddings.float_), 1)

    def test_chat_v2(self):
        """Test chat with v2 client."""
        response = self.client.chat(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "Say hello"}],
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.message)

    def test_chat_vision_v2(self):
        """Test vision with inline image on Command A Vision."""
        import base64, struct, zlib
        # Create a minimal 1x1 red PNG
        raw = b'\x00\xff\x00\x00'
        compressed = zlib.compress(raw)

        def chunk(ctype, data):
            # PNG chunk layout: 4-byte length, type+data, CRC32 over type+data.
            c = ctype + data
            return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)

        ihdr = struct.pack('>IIBBBBB', 1, 1, 8, 2, 0, 0, 0)
        png = b'\x89PNG\r\n\x1a\n' + chunk(b'IHDR', ihdr) + chunk(b'IDAT', compressed) + chunk(b'IEND', b'')
        img_b64 = base64.b64encode(png).decode()
        response = self.client.chat(
            model="command-a-vision",
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is this image? Reply with one word."},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                ],
            }],
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.message)
        self.assertTrue(len(response.message.content) > 0)
        # The 1x1 red pixel should be identified as red
        self.assertIn("red", response.message.content[0].text.lower())

    def test_chat_tool_use_v2(self):
        """Test tool use with v2 client on OCI on-demand inference."""
        response = self.client.chat(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "What's the weather in Toronto?"}],
            max_tokens=200,
            tools=[{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "City name"}
                        },
                        "required": ["location"],
                    },
                },
            }],
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.message)
        # The model must choose the tool path and pass the city through.
        self.assertEqual(response.finish_reason, "TOOL_CALL")
        self.assertTrue(len(response.message.tool_calls) > 0)
        tool_call = response.message.tool_calls[0]
        self.assertEqual(tool_call.function.name, "get_weather")
        self.assertIn("Toronto", tool_call.function.arguments)

    def test_chat_tool_use_response_type_lowered(self):
        """Test that tool_call type is lowercased in response (OCI returns FUNCTION)."""
        response = self.client.chat(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "What's the weather in Toronto?"}],
            max_tokens=200,
            tools=[{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "City name"}
                        },
                        "required": ["location"],
                    },
                },
            }],
        )
        self.assertEqual(response.finish_reason, "TOOL_CALL")
        tool_call = response.message.tool_calls[0]
        # OCI returns "FUNCTION" — SDK must lowercase to "function" for Cohere compat
        self.assertEqual(tool_call.type, "function")

    def test_chat_multi_turn_tool_use_v2(self):
        """Test multi-turn tool use: send tool result back after tool call."""
        # Step 1: Get a tool call
        response = self.client.chat(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "What's the weather in Toronto?"}],
            max_tokens=200,
            tools=[{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "City name"}
                        },
                        "required": ["location"],
                    },
                },
            }],
        )
        self.assertEqual(response.finish_reason, "TOOL_CALL")
        tool_call = response.message.tool_calls[0]
        # Step 2: Send tool result back
        response2 = self.client.chat(
            model="command-a-03-2025",
            messages=[
                {"role": "user", "content": "What's the weather in Toronto?"},
                {
                    "role": "assistant",
                    "tool_calls": [{"id": tool_call.id, "type": "function", "function": {"name": "get_weather", "arguments": tool_call.function.arguments}}],
                    "tool_plan": response.message.tool_plan,
                },
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": [{"type": "text", "text": "15°C, sunny"}],
                },
            ],
            max_tokens=200,
            tools=[{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "City name"}
                        },
                        "required": ["location"],
                    },
                },
            }],
        )
        self.assertIsNotNone(response2.message)
        # Model should respond with text incorporating the tool result
        self.assertTrue(len(response2.message.content) > 0)

    def test_chat_safety_mode_v2(self):
        """Test that safety_mode is uppercased for OCI."""
        # Cohere SDK enum values are already uppercase, but test lowercase too
        response = self.client.chat(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "Say hi"}],
            safety_mode="STRICT",
        )
        self.assertIsNotNone(response.message)

    def test_chat_stream_v2(self):
        """Test V2 streaming chat terminates and produces correct event lifecycle."""
        events = []
        for event in self.client.chat_stream(
            model="command-a-03-2025",
            messages=[{"role": "user", "content": "Count from 1 to 3"}],
        ):
            events.append(event)
        self.assertTrue(len(events) > 0)
        # Verify full event lifecycle: message-start → content-start → content-delta(s) → content-end → message-end
        event_types = [e.type for e in events]
        self.assertEqual(event_types[0], "message-start")
        self.assertIn("content-start", event_types)
        self.assertIn("content-delta", event_types)
        self.assertIn("content-end", event_types)
        self.assertEqual(event_types[-1], "message-end")
        # Verify we can extract text from content-delta events
        full_text = ""
        for event in events:
            # Defensive hasattr chain: only content-delta events carry nested text.
            if (
                hasattr(event, "delta")
                and event.delta
                and hasattr(event.delta, "message")
                and event.delta.message
                and hasattr(event.delta.message, "content")
                and event.delta.message.content
                and hasattr(event.delta.message.content, "text")
                and event.delta.message.content.text is not None
            ):
                full_text += event.delta.message.content.text
        self.assertTrue(len(full_text) > 0)
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
class TestOciClientAuthentication(unittest.TestCase):
    """Exercise the supported OCI authentication paths end-to-end."""

    def test_config_file_auth(self):
        """Authenticate via the OCI config file and make one embed call."""
        cid = os.getenv("OCI_COMPARTMENT_ID")
        if not cid:
            self.skipTest("OCI_COMPARTMENT_ID not set")
        client = cohere.OciClientV2(
            oci_region="us-chicago-1",
            oci_compartment_id=cid,
            oci_profile=os.getenv("OCI_PROFILE", "DEFAULT"),
        )
        # A single embed call is enough to prove the credentials work.
        result = client.embed(
            model="embed-english-v3.0",
            texts=["Auth test"],
            input_type="search_document",
        )
        self.assertIsNotNone(result)
        self.assertIsNotNone(result.embeddings)

    def test_custom_profile_auth(self):
        """Authenticate via an explicitly named OCI profile."""
        cid = os.getenv("OCI_COMPARTMENT_ID")
        chosen_profile = os.getenv("OCI_PROFILE", "DEFAULT")
        if not cid:
            self.skipTest("OCI_COMPARTMENT_ID not set")
        client = cohere.OciClientV2(
            oci_profile=chosen_profile,
            oci_region="us-chicago-1",
            oci_compartment_id=cid,
        )
        result = client.embed(
            model="embed-english-v3.0",
            texts=["Profile auth test"],
            input_type="search_document",
        )
        self.assertIsNotNone(result)
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
class TestOciClientErrors(unittest.TestCase):
    """Failure-path behavior of the OCI client."""

    def test_missing_compartment_id(self):
        """Constructing without a compartment OCID must raise TypeError."""
        with self.assertRaises(TypeError):
            cohere.OciClientV2(
                oci_region="us-chicago-1",
                # oci_compartment_id intentionally omitted
            )

    def test_invalid_model(self):
        """An unknown model name must surface an error from OCI."""
        cid = os.getenv("OCI_COMPARTMENT_ID")
        if not cid:
            self.skipTest("OCI_COMPARTMENT_ID not set")
        client = cohere.OciClientV2(
            oci_region="us-chicago-1",
            oci_compartment_id=cid,
            oci_profile=os.getenv("OCI_PROFILE", "DEFAULT"),
        )
        with self.assertRaises(Exception):
            client.embed(
                model="invalid-model-name",
                texts=["Test"],
                input_type="search_document",
            )
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
class TestOciClientModels(unittest.TestCase):
"""Test different Cohere models on OCI."""
def setUp(self):
    """Build an OciClientV2 from environment configuration, or skip."""
    cid = os.getenv("OCI_COMPARTMENT_ID")
    if not cid:
        self.skipTest("OCI_COMPARTMENT_ID not set")
    self.client = cohere.OciClientV2(
        oci_region=os.getenv("OCI_REGION", "us-chicago-1"),
        oci_compartment_id=cid,
        oci_profile=os.getenv("OCI_PROFILE", "DEFAULT"),
    )
def test_embed_english_v3(self):
    """embed-english-v3.0 must yield 1024-dimensional float vectors."""
    resp = self.client.embed(
        input_type="search_document",
        texts=["Test"],
        model="embed-english-v3.0",
    )
    self.assertIsNotNone(resp.embeddings)
    self.assertIsNotNone(resp.embeddings.float_)
    self.assertEqual(1024, len(resp.embeddings.float_[0]))
def test_embed_multilingual_v3(self):
    """embed-multilingual-v3.0 must yield 1024-dimensional float vectors."""
    resp = self.client.embed(
        input_type="search_document",
        texts=["Test"],
        model="embed-multilingual-v3.0",
    )
    self.assertIsNotNone(resp.embeddings)
    self.assertIsNotNone(resp.embeddings.float_)
    self.assertEqual(1024, len(resp.embeddings.float_[0]))
def test_command_a_chat(self):
"""Test command-a-03-2025 model for chat."""
response = self.client.chat(
model="command-a-03-2025",
messages=[{"role": "user", "content": "Hello"}],
)
self.assertIsNotNone(response.message)
def test_embed_english_light_v3(self):
"""Test embed-english-light-v3.0 returns 384-dim vectors."""
response = self.client.embed(
model="embed-english-light-v3.0",
texts=["Hello world"],
input_type="search_document",
)
self.assertIsNotNone(response.embeddings.float_)
self.assertEqual(len(response.embeddings.float_[0]), 384)
def test_embed_multilingual_light_v3(self):
"""Test embed-multilingual-light-v3.0 returns 384-dim vectors."""
response = self.client.embed(
model="embed-multilingual-light-v3.0",
texts=["Bonjour le monde"],
input_type="search_document",
)
self.assertIsNotNone(response.embeddings.float_)
self.assertEqual(len(response.embeddings.float_[0]), 384)
def test_embed_search_query_input_type(self):
"""Test embed with search_query input_type (distinct from search_document)."""
response = self.client.embed(
model="embed-english-v3.0",
texts=["What is the capital of France?"],
input_type="search_query",
)
self.assertIsNotNone(response.embeddings.float_)
self.assertEqual(len(response.embeddings.float_[0]), 1024)
def test_embed_with_embedding_types(self):
"""Test embed with explicit embedding_types parameter."""
response = self.client.embed(
model="embed-english-v3.0",
texts=["Hello world"],
input_type="search_document",
embedding_types=["float"],
)
self.assertIsNotNone(response.embeddings.float_)
self.assertEqual(len(response.embeddings.float_[0]), 1024)
def test_embed_with_truncate(self):
"""Test embed with truncate parameter."""
long_text = "hello " * 1000
for mode in ["NONE", "START", "END"]:
response = self.client.embed(
model="embed-english-v3.0",
texts=[long_text],
input_type="search_document",
truncate=mode,
)
self.assertIsNotNone(response.embeddings.float_)
self.assertEqual(len(response.embeddings.float_[0]), 1024)
def test_command_r_plus_chat(self):
"""Test command-r-plus-08-2024 via V1 client."""
v1_client = cohere.OciClient(
oci_region=os.getenv("OCI_REGION", "us-chicago-1"),
oci_compartment_id=os.getenv("OCI_COMPARTMENT_ID"),
oci_profile=os.getenv("OCI_PROFILE", "DEFAULT"),
)
response = v1_client.chat(
model="command-r-plus-08-2024",
message="What is 2+2? Answer with just the number.",
)
self.assertIsNotNone(response.text)
self.assertIn("4", response.text)
def test_v2_multi_turn_chat(self):
"""Test V2 chat with conversation history (multi-turn)."""
response = self.client.chat(
model="command-a-03-2025",
messages=[
{"role": "user", "content": "My name is Alice."},
{"role": "assistant", "content": "Nice to meet you, Alice!"},
{"role": "user", "content": "What is my name?"},
],
)
self.assertIsNotNone(response.message)
content = response.message.content[0].text
self.assertIn("Alice", content)
def test_v2_system_message(self):
"""Test V2 chat with a system message."""
response = self.client.chat(
model="command-a-03-2025",
messages=[
{"role": "system", "content": "You are a helpful assistant. Always respond in exactly 3 words."},
{"role": "user", "content": "Say hello."},
],
)
self.assertIsNotNone(response.message)
self.assertIsNotNone(response.message.content[0].text)
class TestOciClientTransformations(unittest.TestCase):
    """Unit tests for OCI request/response transformations (no OCI credentials required)."""
    # These tests exercise pure helper functions from cohere.oci_client
    # (request/response/stream translation) with literal fixtures, plus a
    # few auth-config paths driven entirely through unittest.mock — no
    # network access is performed anywhere in this class.

    # ---- `thinking` parameter translation (Cohere -> OCI camelCase) ----

    def test_thinking_parameter_transformation(self):
        """Test that thinking parameter is correctly transformed to OCI format."""
        from cohere.oci_client import transform_request_to_oci
        cohere_body = {
            "model": "command-a-reasoning-08-2025",
            "messages": [{"role": "user", "content": "What is 2+2?"}],
            "thinking": {
                "type": "enabled",
                "token_budget": 10000,
            },
        }
        result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True)
        # Verify thinking parameter is transformed with camelCase for OCI API
        chat_request = result["chatRequest"]
        self.assertIn("thinking", chat_request)
        self.assertEqual(chat_request["thinking"]["type"], "ENABLED")
        self.assertEqual(chat_request["thinking"]["tokenBudget"], 10000)  # camelCase for OCI

    def test_thinking_parameter_disabled(self):
        """Test that disabled thinking is correctly transformed."""
        from cohere.oci_client import transform_request_to_oci
        cohere_body = {
            "model": "command-a-reasoning-08-2025",
            "messages": [{"role": "user", "content": "Hello"}],
            "thinking": {
                "type": "disabled",
            },
        }
        result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True)
        chat_request = result["chatRequest"]
        self.assertIn("thinking", chat_request)
        self.assertEqual(chat_request["thinking"]["type"], "DISABLED")
        # No budget was supplied, so none should be invented.
        self.assertNotIn("token_budget", chat_request["thinking"])

    def test_thinking_response_transformation(self):
        """Test that thinking content in response is correctly transformed."""
        from cohere.oci_client import transform_oci_response_to_cohere
        oci_response = {
            "chatResponse": {
                "id": "test-id",
                "message": {
                    "role": "ASSISTANT",
                    "content": [
                        {"type": "THINKING", "thinking": "Let me think about this..."},
                        {"type": "TEXT", "text": "The answer is 4."},
                    ],
                },
                "finishReason": "COMPLETE",
                "usage": {"inputTokens": 10, "completionTokens": 20},
            }
        }
        result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True)
        # Verify content types are lowercased
        self.assertEqual(result["message"]["content"][0]["type"], "thinking")
        self.assertEqual(result["message"]["content"][1]["type"], "text")

    # ---- Single stream-event translation ----

    def test_stream_event_thinking_transformation(self):
        """Test that thinking content in stream events is correctly transformed."""
        from cohere.oci_client import transform_stream_event
        # OCI thinking event
        oci_event = {
            "message": {
                "content": [{"type": "THINKING", "thinking": "Reasoning step..."}]
            }
        }
        result = transform_stream_event("chat", oci_event, is_v2=True)
        self.assertEqual(result[0]["type"], "content-delta")
        self.assertIn("thinking", result[0]["delta"]["message"]["content"])
        self.assertEqual(result[0]["delta"]["message"]["content"]["thinking"], "Reasoning step...")

    def test_stream_event_text_transformation(self):
        """Test that text content in stream events is correctly transformed."""
        from cohere.oci_client import transform_stream_event
        # OCI text event
        oci_event = {
            "message": {
                "content": [{"type": "TEXT", "text": "The answer is..."}]
            }
        }
        result = transform_stream_event("chat", oci_event, is_v2=True)
        self.assertEqual(result[0]["type"], "content-delta")
        self.assertIn("text", result[0]["delta"]["message"]["content"])
        self.assertEqual(result[0]["delta"]["message"]["content"]["text"], "The answer is...")

    def test_thinking_parameter_none(self):
        """Test that thinking=None does not crash (issue: null guard)."""
        from cohere.oci_client import transform_request_to_oci
        cohere_body = {
            "model": "command-a-03-2025",
            "messages": [{"role": "user", "content": "Hello"}],
            "thinking": None,  # Explicitly set to None
        }
        # Should not crash with TypeError
        result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True)
        chat_request = result["chatRequest"]
        # thinking should not be in request when None
        self.assertNotIn("thinking", chat_request)

    # ---- V2 chat response translation ----

    def test_v2_response_role_lowercased(self):
        """Test that V2 response message role is lowercased."""
        from cohere.oci_client import transform_oci_response_to_cohere
        oci_response = {
            "chatResponse": {
                "id": "test-id",
                "message": {
                    "role": "ASSISTANT",
                    "content": [{"type": "TEXT", "text": "Hello"}],
                },
                "finishReason": "COMPLETE",
                "usage": {"inputTokens": 10, "completionTokens": 20},
            }
        }
        result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True)
        # Role should be lowercased
        self.assertEqual(result["message"]["role"], "assistant")

    def test_v2_response_finish_reason_uppercase(self):
        """Test that V2 response finish_reason stays uppercase."""
        from cohere.oci_client import transform_oci_response_to_cohere
        oci_response = {
            "chatResponse": {
                "id": "test-id",
                "message": {
                    "role": "ASSISTANT",
                    "content": [{"type": "TEXT", "text": "Hello"}],
                },
                "finishReason": "MAX_TOKENS",
                "usage": {"inputTokens": 10, "completionTokens": 20},
            }
        }
        result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True)
        # V2 finish_reason should stay uppercase
        self.assertEqual(result["finish_reason"], "MAX_TOKENS")

    def test_v2_response_tool_calls_conversion(self):
        """Test that V2 response converts toolCalls to tool_calls."""
        from cohere.oci_client import transform_oci_response_to_cohere
        oci_response = {
            "chatResponse": {
                "id": "test-id",
                "message": {
                    "role": "ASSISTANT",
                    "content": [{"type": "TEXT", "text": "I'll help with that."}],
                    "toolCalls": [
                        {
                            "id": "call_123",
                            "type": "function",
                            "function": {"name": "get_weather", "arguments": '{"city": "London"}'},
                        }
                    ],
                },
                "finishReason": "TOOL_CALL",
                "usage": {"inputTokens": 10, "completionTokens": 20},
            }
        }
        result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True)
        # toolCalls should be converted to tool_calls
        self.assertIn("tool_calls", result["message"])
        self.assertNotIn("toolCalls", result["message"])
        self.assertEqual(len(result["message"]["tool_calls"]), 1)
        self.assertEqual(result["message"]["tool_calls"][0]["id"], "call_123")

    # ---- Model-name normalization and request building ----

    def test_normalize_model_for_oci(self):
        """Test model name normalization for OCI."""
        from cohere.oci_client import normalize_model_for_oci
        # Plain model name gets cohere. prefix
        self.assertEqual(normalize_model_for_oci("command-a-03-2025"), "cohere.command-a-03-2025")
        # Already prefixed passes through
        self.assertEqual(normalize_model_for_oci("cohere.embed-english-v3.0"), "cohere.embed-english-v3.0")
        # OCID passes through
        self.assertEqual(
            normalize_model_for_oci("ocid1.generativeaimodel.oc1.us-chicago-1.abc"),
            "ocid1.generativeaimodel.oc1.us-chicago-1.abc",
        )

    def test_transform_embed_request(self):
        """Test embed request transformation to OCI format."""
        from cohere.oci_client import transform_request_to_oci
        body = {
            "model": "embed-english-v3.0",
            "texts": ["hello", "world"],
            "input_type": "search_document",
            "truncate": "end",
            "embedding_types": ["float", "int8"],
        }
        result = transform_request_to_oci("embed", body, "compartment-123")
        self.assertEqual(result["inputs"], ["hello", "world"])
        self.assertEqual(result["inputType"], "SEARCH_DOCUMENT")
        self.assertEqual(result["truncate"], "END")
        self.assertEqual(result["embeddingTypes"], ["float", "int8"])
        self.assertEqual(result["compartmentId"], "compartment-123")
        self.assertEqual(result["servingMode"]["modelId"], "cohere.embed-english-v3.0")

    def test_transform_embed_request_with_optional_params(self):
        """Test embed request forwards optional params."""
        from cohere.oci_client import transform_request_to_oci
        body = {
            "model": "embed-english-v3.0",
            "inputs": [{"content": [{"type": "text", "text": "hello"}]}],
            "input_type": "classification",
            "max_tokens": 256,
            "output_dimension": 512,
            "priority": 42,
        }
        result = transform_request_to_oci("embed", body, "compartment-123")
        self.assertEqual(result["inputs"], body["inputs"])
        self.assertEqual(result["maxTokens"], 256)
        self.assertEqual(result["outputDimension"], 512)
        self.assertEqual(result["priority"], 42)

    def test_transform_embed_request_rejects_images(self):
        """Test embed request fails clearly for unsupported top-level images."""
        from cohere.oci_client import transform_request_to_oci
        with self.assertRaises(ValueError) as ctx:
            transform_request_to_oci(
                "embed",
                {
                    "model": "embed-english-v3.0",
                    "images": ["data:image/png;base64,abc"],
                    "input_type": "classification",
                },
                "compartment-123",
            )
        self.assertIn("top-level 'images' parameter", str(ctx.exception))

    def test_transform_chat_request_optional_params(self):
        """Test chat request transformation includes optional params."""
        from cohere.oci_client import transform_request_to_oci
        body = {
            "model": "command-a-03-2025",
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 100,
            "temperature": 0.7,
            "stop_sequences": ["END"],
            "frequency_penalty": 0.5,
            "strict_tools": True,
            "response_format": {"type": "json_object"},
            "logprobs": True,
            "tool_choice": "REQUIRED",
            "priority": 7,
        }
        result = transform_request_to_oci("chat", body, "compartment-123", is_v2=True)
        chat_req = result["chatRequest"]
        self.assertEqual(chat_req["maxTokens"], 100)
        self.assertEqual(chat_req["temperature"], 0.7)
        self.assertEqual(chat_req["stopSequences"], ["END"])
        self.assertEqual(chat_req["frequencyPenalty"], 0.5)
        self.assertTrue(chat_req["strictTools"])
        self.assertEqual(chat_req["responseFormat"], {"type": "json_object"})
        self.assertTrue(chat_req["logprobs"])
        self.assertEqual(chat_req["toolChoice"], "REQUIRED")
        self.assertEqual(chat_req["priority"], 7)

    # ---- V1/V2 request-shape validation ----

    def test_v2_client_rejects_v1_request(self):
        """Test OciClientV2 fails when given V1-style 'message' string."""
        from cohere.oci_client import transform_request_to_oci
        with self.assertRaises(ValueError) as ctx:
            transform_request_to_oci(
                "chat",
                {"model": "command-a-03-2025", "message": "Hello"},
                "compartment-123",
                is_v2=True,
            )
        self.assertIn("OciClientV2", str(ctx.exception))

    def test_v1_client_rejects_v2_request(self):
        """Test OciClient fails when given V2-style 'messages' array."""
        from cohere.oci_client import transform_request_to_oci
        with self.assertRaises(ValueError) as ctx:
            transform_request_to_oci(
                "chat",
                {"model": "command-r-08-2024", "messages": [{"role": "user", "content": "Hi"}]},
                "compartment-123",
                is_v2=False,
            )
        # Trailing space distinguishes "OciClient " from "OciClientV2".
        self.assertIn("OciClient ", str(ctx.exception))

    def test_unsupported_endpoint_raises(self):
        """Test that transform_request_to_oci raises for unsupported endpoints."""
        from cohere.oci_client import transform_request_to_oci
        with self.assertRaises(ValueError) as ctx:
            transform_request_to_oci("rerank", {"model": "rerank-v3.5"}, "compartment-123")
        self.assertIn("rerank", str(ctx.exception))
        self.assertIn("not supported", str(ctx.exception))

    def test_v1_chat_request_optional_params(self):
        """Test V1 chat request forwards supported optional params."""
        from cohere.oci_client import transform_request_to_oci
        body = {
            "model": "command-r-08-2024",
            "message": "Hi",
            "max_tokens": 100,
            "temperature": 0.7,
            "k": 10,
            "p": 0.8,
            "seed": 123,
            "stop_sequences": ["END"],
            "frequency_penalty": 0.5,
            "presence_penalty": 0.2,
            "documents": [{"title": "Doc", "text": "Body"}],
            "tools": [{"name": "lookup"}],
            "tool_results": [{"call": {"name": "lookup"}}],
            "response_format": {"type": "json_object"},
            "safety_mode": "NONE",
            "priority": 4,
        }
        result = transform_request_to_oci("chat", body, "compartment-123", is_v2=False)
        chat_req = result["chatRequest"]
        self.assertEqual(chat_req["apiFormat"], "COHERE")
        self.assertEqual(chat_req["message"], "Hi")
        self.assertEqual(chat_req["maxTokens"], 100)
        self.assertEqual(chat_req["temperature"], 0.7)
        self.assertEqual(chat_req["topK"], 10)
        self.assertEqual(chat_req["topP"], 0.8)
        self.assertEqual(chat_req["seed"], 123)
        self.assertEqual(chat_req["frequencyPenalty"], 0.5)
        self.assertEqual(chat_req["presencePenalty"], 0.2)
        self.assertEqual(chat_req["priority"], 4)

    def test_v1_stream_wrapper_preserves_finish_reason(self):
        """Test V1 stream-end uses the OCI finish reason from the final event."""
        import json
        from cohere.oci_client import transform_oci_stream_wrapper
        chunks = [
            b'data: {"text": "Hello", "isFinished": false}\n',
            b'data: {"text": " world", "isFinished": true, "finishReason": "MAX_TOKENS"}\n',
            b"data: [DONE]\n",
        ]
        events = [
            json.loads(raw.decode("utf-8"))
            for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=False)
        ]
        # First event should be stream-start with generation_id
        self.assertEqual(events[0]["event_type"], "stream-start")
        self.assertIn("generation_id", events[0])
        self.assertEqual(events[3]["event_type"], "stream-end")
        self.assertEqual(events[3]["finish_reason"], "MAX_TOKENS")
        self.assertEqual(events[3]["response"]["text"], "Hello world")

    def test_transform_chat_request_tool_message_fields(self):
        """Test tool message fields are converted to OCI names."""
        from cohere.oci_client import transform_request_to_oci
        body = {
            "model": "command-a-03-2025",
            "messages": [
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": "Use tool"}],
                    "tool_calls": [{"id": "call_1"}],
                    "tool_plan": "Plan",
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_1",
                    "content": [{"type": "text", "text": "Result"}],
                },
            ],
        }
        result = transform_request_to_oci("chat", body, "compartment-123", is_v2=True)
        assistant_message, tool_message = result["chatRequest"]["messages"]
        self.assertEqual(assistant_message["toolCalls"], [{"id": "call_1"}])
        self.assertEqual(assistant_message["toolPlan"], "Plan")
        self.assertEqual(tool_message["toolCallId"], "call_1")

    # ---- Endpoint URL construction ----

    def test_get_oci_url_known_endpoints(self):
        """Test URL generation for known endpoints."""
        from cohere.oci_client import get_oci_url
        url = get_oci_url("us-chicago-1", "embed")
        self.assertIn("/actions/embedText", url)
        url = get_oci_url("us-chicago-1", "chat")
        self.assertIn("/actions/chat", url)

    def test_get_oci_url_unknown_endpoint_raises(self):
        """Test that unknown endpoints raise ValueError instead of producing bad URLs."""
        from cohere.oci_client import get_oci_url
        with self.assertRaises(ValueError) as ctx:
            get_oci_url("us-chicago-1", "unknown_endpoint")
        self.assertIn("not supported", str(ctx.exception))

    # ---- Auth configuration and request signing (fully mocked) ----

    def test_load_oci_config_missing_private_key_raises(self):
        """Test that direct credentials without private key raises clear error."""
        from cohere.oci_client import _load_oci_config
        with patch("cohere.oci_client.lazy_oci", return_value=MagicMock()):
            with self.assertRaises(ValueError) as ctx:
                _load_oci_config(
                    auth_type="api_key",
                    config_path=None,
                    profile=None,
                    user_id="ocid1.user.oc1...",
                    fingerprint="xx:xx:xx",
                    tenancy_id="ocid1.tenancy.oc1...",
                    # No private_key_path or private_key_content
                )
        self.assertIn("oci_private_key_path", str(ctx.exception))

    def test_load_oci_config_ignores_inherited_session_auth(self):
        """Test that named API-key profiles do not inherit DEFAULT session auth fields."""
        from cohere.oci_client import _load_oci_config
        config_text = """
[DEFAULT]
security_token_file=/tmp/default-token
[API_KEY_AUTH]
user=ocid1.user.oc1..test
fingerprint=aa:bb
key_file=/tmp/test.pem
tenancy=ocid1.tenancy.oc1..test
region=us-chicago-1
""".strip()
        with tempfile.NamedTemporaryFile("w", delete=False) as config_file:
            config_file.write(config_text)
            config_path = config_file.name
        try:
            mock_oci = MagicMock()
            # from_file is mocked to simulate oci's DEFAULT-section inheritance,
            # returning the session field the loader is expected to drop.
            mock_oci.config.from_file.return_value = {
                "user": "ocid1.user.oc1..test",
                "fingerprint": "aa:bb",
                "key_file": "/tmp/test.pem",
                "tenancy": "ocid1.tenancy.oc1..test",
                "region": "us-chicago-1",
                "security_token_file": "/tmp/default-token",
            }
            with patch("cohere.oci_client.lazy_oci", return_value=mock_oci):
                config = _load_oci_config(
                    auth_type="api_key",
                    config_path=config_path,
                    profile="API_KEY_AUTH",
                )
        finally:
            os.unlink(config_path)
        self.assertNotIn("security_token_file", config)

    def test_session_auth_prefers_security_token_signer(self):
        """Test session-based auth uses SecurityTokenSigner before API key signer."""
        from cohere.oci_client import map_request_to_oci
        mock_oci = MagicMock()
        mock_security_signer = MagicMock()
        mock_oci.signer.load_private_key_from_file.return_value = "private-key"
        mock_oci.auth.signers.SecurityTokenSigner.return_value = mock_security_signer
        with patch("cohere.oci_client.lazy_oci", return_value=mock_oci), patch(
            "builtins.open", mock_open(read_data="session-token")
        ):
            hook = map_request_to_oci(
                oci_config={
                    "user": "ocid1.user.oc1..example",
                    "fingerprint": "xx:xx",
                    "tenancy": "ocid1.tenancy.oc1..example",
                    "security_token_file": "~/.oci/token",
                    "key_file": "~/.oci/key.pem",
                },
                oci_region="us-chicago-1",
                oci_compartment_id="ocid1.compartment.oc1..example",
            )
            request = MagicMock()
            request.url.path = "/v2/embed"
            request.read.return_value = b'{"model":"embed-english-v3.0","texts":["hello"]}'
            request.method = "POST"
            request.extensions = {}
            hook(request)
        # SecurityTokenSigner is called at least once (init) and again per request
        # (token file is re-read on each signing call to pick up refreshed tokens).
        mock_oci.auth.signers.SecurityTokenSigner.assert_called_with(
            token="session-token",
            private_key="private-key",
        )
        self.assertGreaterEqual(mock_oci.auth.signers.SecurityTokenSigner.call_count, 1)
        mock_oci.signer.Signer.assert_not_called()

    def test_session_token_refreshed_on_subsequent_requests(self):
        """Verify the refreshing signer picks up a new token written to the token file."""
        import tempfile
        import os
        from cohere.oci_client import map_request_to_oci
        mock_oci = MagicMock()
        mock_oci.signer.load_private_key_from_file.return_value = "private-key"
        # Write initial token to a real temp file so we can overwrite it later.
        with tempfile.NamedTemporaryFile("w", suffix=".token", delete=False) as tf:
            tf.write("token-v1")
            token_path = tf.name
        try:
            with patch("cohere.oci_client.lazy_oci", return_value=mock_oci):
                hook = map_request_to_oci(
                    oci_config={
                        "security_token_file": token_path,
                        "key_file": "/irrelevant.pem",
                    },
                    oci_region="us-chicago-1",
                    oci_compartment_id="ocid1.compartment.oc1..example",
                )

                def _make_request():
                    # Fresh mock per call so each hook invocation signs anew.
                    req = MagicMock()
                    req.url.path = "/v2/embed"
                    req.read.return_value = b'{"model":"embed-english-v3.0","texts":["hi"]}'
                    req.method = "POST"
                    req.extensions = {}
                    return req

                # First request uses token-v1
                hook(_make_request())
                calls_after_first = mock_oci.auth.signers.SecurityTokenSigner.call_count
                # Simulate token refresh by overwriting the file
                with open(token_path, "w") as _f:
                    _f.write("token-v2")
                # Second request — should re-read and use token-v2
                hook(_make_request())
                self.assertGreater(
                    mock_oci.auth.signers.SecurityTokenSigner.call_count,
                    calls_after_first,
                    "SecurityTokenSigner should be re-instantiated after token file update",
                )
                # Verify the latest call used the refreshed token
                all_calls = mock_oci.auth.signers.SecurityTokenSigner.call_args_list
                last_call = all_calls[-1]
                last_token = last_call.kwargs.get("token") or (last_call.args[0] if last_call.args else None)
                self.assertEqual(last_token, "token-v2", "Last signing call must use the refreshed token")
        finally:
            os.unlink(token_path)

    # ---- Embed response translation ----

    def test_embed_response_lowercases_embedding_keys(self):
        """Test embed response uses lowercase keys expected by the SDK model."""
        from cohere.oci_client import transform_oci_response_to_cohere
        result = transform_oci_response_to_cohere(
            "embed",
            {
                "id": "embed-id",
                "embeddings": {"FLOAT": [[0.1, 0.2]], "INT8": [[1, 2]]},
                "usage": {"inputTokens": 3, "completionTokens": 7},
            },
            is_v2=True,
        )
        self.assertIn("float", result["embeddings"])
        self.assertIn("int8", result["embeddings"])
        self.assertNotIn("FLOAT", result["embeddings"])
        self.assertEqual(result["meta"]["tokens"]["output_tokens"], 7)

    def test_embed_response_includes_response_type_v1(self):
        """Test V1 embed response includes response_type=embeddings_floats for SDK union."""
        from cohere.oci_client import transform_oci_response_to_cohere
        result = transform_oci_response_to_cohere(
            "embed",
            {
                "id": "embed-id",
                "embeddings": [[0.1, 0.2]],
                "usage": {"inputTokens": 3, "completionTokens": 0},
            },
            is_v2=False,
        )
        self.assertEqual(result["response_type"], "embeddings_floats")

    def test_embed_response_includes_response_type_v2(self):
        """Test V2 embed response includes response_type=embeddings_by_type for SDK union."""
        from cohere.oci_client import transform_oci_response_to_cohere
        result = transform_oci_response_to_cohere(
            "embed",
            {
                "id": "embed-id",
                "embeddings": {"FLOAT": [[0.1, 0.2]]},
                "usage": {"inputTokens": 3, "completionTokens": 0},
            },
            is_v2=True,
        )
        self.assertEqual(result["response_type"], "embeddings_by_type")

    def test_normalize_model_for_oci_rejects_empty_model(self):
        """Test model normalization fails clearly for empty model names."""
        from cohere.oci_client import normalize_model_for_oci
        with self.assertRaises(ValueError) as ctx:
            normalize_model_for_oci("")
        self.assertIn("non-empty model", str(ctx.exception))

    # ---- V2 stream wrapper: full SSE event lifecycle ----

    def test_stream_wrapper_emits_full_event_lifecycle(self):
        """Test that stream emits message-start, content-start, content-delta, content-end, message-end."""
        import json
        from cohere.oci_client import transform_oci_stream_wrapper
        chunks = [
            b'data: {"message": {"content": [{"type": "TEXT", "text": "Hello"}]}}\n',
            b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE"}\n',
            b'data: [DONE]\n',
        ]
        events = []
        for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True):
            line = raw.decode("utf-8").strip()
            if line.startswith("data: "):
                events.append(json.loads(line[6:]))
        event_types = [e["type"] for e in events]
        self.assertEqual(event_types[0], "message-start")
        self.assertEqual(event_types[1], "content-start")
        self.assertEqual(event_types[2], "content-delta")
        self.assertEqual(event_types[3], "content-delta")
        self.assertEqual(event_types[4], "content-end")
        self.assertEqual(event_types[5], "message-end")
        # Verify message-start has id and role
        self.assertIn("id", events[0])
        self.assertEqual(events[0]["delta"]["message"]["role"], "assistant")
        # Verify content-start has index and type
        self.assertEqual(events[1]["index"], 0)
        self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "text")
        self.assertEqual(events[5]["delta"]["finish_reason"], "COMPLETE")

    def test_stream_wrapper_emits_new_content_block_on_thinking_transition(self):
        """Test streams emit a new content block when transitioning from thinking to text."""
        import json
        from cohere.oci_client import transform_oci_stream_wrapper
        chunks = [
            b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n',
            b'data: {"message": {"content": [{"type": "TEXT", "text": "Answer"}]}, "finishReason": "COMPLETE"}\n',
            b"data: [DONE]\n",
        ]
        events = []
        for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True):
            line = raw.decode("utf-8").strip()
            if line.startswith("data: "):
                events.append(json.loads(line[6:]))
        self.assertEqual(events[1]["type"], "content-start")
        self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "thinking")
        self.assertEqual(events[2]["type"], "content-delta")
        self.assertEqual(events[2]["index"], 0)
        self.assertEqual(events[3], {"type": "content-end", "index": 0})
        self.assertEqual(events[4]["type"], "content-start")
        self.assertEqual(events[4]["index"], 1)
        self.assertEqual(events[4]["delta"]["message"]["content"]["type"], "text")
        self.assertEqual(events[5]["type"], "content-delta")
        self.assertEqual(events[5]["index"], 1)

    def test_stream_wrapper_no_spurious_block_on_finish_only_event(self):
        """Finish-only event after thinking block must not open a spurious empty text block."""
        import json
        from cohere.oci_client import transform_oci_stream_wrapper
        chunks = [
            b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n',
            b'data: {"finishReason": "COMPLETE"}\n',
            b"data: [DONE]\n",
        ]
        events = []
        for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True):
            line = raw.decode("utf-8").strip()
            if line.startswith("data: "):
                events.append(json.loads(line[6:]))
        types = [e["type"] for e in events]
        # Must not contain two content-start events
        self.assertEqual(types.count("content-start"), 1)
        # The single content block must be thinking
        cs = next(e for e in events if e["type"] == "content-start")
        self.assertEqual(cs["delta"]["message"]["content"]["type"], "thinking")
        # Must end cleanly
        self.assertEqual(events[-1]["type"], "message-end")

    def test_stream_wrapper_skips_malformed_json_with_warning(self):
        """Test that malformed JSON in SSE stream is skipped."""
        from cohere.oci_client import transform_oci_stream_wrapper
        chunks = [
            b'data: not-valid-json\n',
            b'data: {"message": {"content": [{"type": "TEXT", "text": "hello"}]}}\n',
            b'data: [DONE]\n',
        ]
        events = list(transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True))
        # Should get message-start + content-start + content-delta + content-end + message-end.
        self.assertEqual(len(events), 5)

    def test_stream_wrapper_skips_message_end_for_empty_stream(self):
        """Test empty streams do not emit message-end without a preceding message-start."""
        from cohere.oci_client import transform_oci_stream_wrapper
        events = list(transform_oci_stream_wrapper(iter([b"data: [DONE]\n"]), "chat", is_v2=True))
        self.assertEqual(events, [])

    def test_stream_wrapper_done_uses_current_content_index_after_transition(self):
        """Test fallback content-end uses the latest content index after type transitions."""
        import json
        from cohere.oci_client import transform_oci_stream_wrapper
        chunks = [
            b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n',
            b'data: {"message": {"content": [{"type": "TEXT", "text": "Answer"}]}}\n',
            b"data: [DONE]\n",
        ]
        events = []
        for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True):
            line = raw.decode("utf-8").strip()
            if line.startswith("data: "):
                events.append(json.loads(line[6:]))
        self.assertEqual(events[-2], {"type": "content-end", "index": 1})
        self.assertEqual(events[-1]["type"], "message-end")

    def test_stream_wrapper_raises_on_transform_error(self):
        """Test that transform errors in stream produce OCI-specific error."""
        from cohere.oci_client import transform_oci_stream_wrapper
        # Event with structure that will cause transform_stream_event to fail
        # (message is None, causing TypeError on "content" in None)
        chunks = [
            b'data: {"message": null}\n',
        ]
        with self.assertRaises(RuntimeError) as ctx:
            list(transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True))
        self.assertIn("OCI stream event transformation failed", str(ctx.exception))

    def test_stream_event_finish_reason_keeps_final_text(self):
        """Test finish events keep final text before content-end."""
        from cohere.oci_client import transform_stream_event
        events = transform_stream_event(
            "chat",
            {
                "message": {"content": [{"type": "TEXT", "text": " world"}]},
                "finishReason": "COMPLETE",
            },
            is_v2=True,
        )
        self.assertEqual(events[0]["type"], "content-delta")
        self.assertEqual(events[0]["delta"]["message"]["content"]["text"], " world")
        self.assertEqual(events[1]["type"], "content-end")
# Allow running this test module directly (equivalent to `python -m unittest`).
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/test_oci_mypy.py
================================================
"""Mypy type-checking gate for OCI client code.
Runs mypy on OCI source and test files and fails if any type errors are found.
This prevents type regressions from being introduced silently.
Run with:
pytest tests/test_oci_mypy.py
"""
import os
import shutil
import subprocess
import unittest
# Resolved path to the mypy executable, or None when mypy is not installed;
# the test class below is skipped entirely in the None case.
MYPY_BIN = shutil.which("mypy")

# Files that must stay mypy-clean
OCI_SOURCE_FILES = [
    "src/cohere/oci_client.py",
    "src/cohere/manually_maintained/lazy_oci_deps.py",
]
OCI_TEST_FILES = [
    "tests/test_oci_client.py",
]

# --follow-imports=silent prevents mypy from crawling into transitive
# dependencies (e.g. the AWS client) that have pre-existing errors.
_MYPY_BASE = [
    "--config-file", "mypy.ini",
    "--follow-imports=silent",
]
def _run_mypy(files: list[str], extra_env: dict[str, str] | None = None) -> tuple[int, str]:
    """Invoke mypy over *files* and return ``(exit_code, combined_output)``.

    ``extra_env`` entries are layered on top of the current process
    environment before the subprocess is launched.
    """
    assert MYPY_BIN is not None
    merged_env = dict(os.environ)
    if extra_env:
        merged_env.update(extra_env)
    completed = subprocess.run(
        [MYPY_BIN, *_MYPY_BASE, *files],
        capture_output=True,
        text=True,
        env=merged_env,
    )
    combined = completed.stdout + completed.stderr
    return completed.returncode, combined.strip()
@unittest.skipIf(MYPY_BIN is None, "mypy not found on PATH")
class TestOciMypy(unittest.TestCase):
    """Type-checking gate: OCI source and test files must stay mypy-clean."""

    def test_oci_source_types(self):
        """OCI source files must be free of mypy errors."""
        exit_code, report = _run_mypy(OCI_SOURCE_FILES)
        self.assertEqual(exit_code, 0, f"mypy found type errors in OCI source:\n{report}")

    def test_oci_test_types(self):
        """OCI test files must be free of mypy errors."""
        # PYTHONPATH=src so mypy can resolve `import cohere`
        exit_code, report = _run_mypy(OCI_TEST_FILES, extra_env={"PYTHONPATH": "src"})
        self.assertEqual(exit_code, 0, f"mypy found type errors in OCI tests:\n{report}")
# Allow running this test module directly (equivalent to `python -m unittest`).
if __name__ == "__main__":
    unittest.main()
================================================
FILE: tests/test_overrides.py
================================================
import unittest
from contextlib import redirect_stderr
import logging
from cohere import EmbedByTypeResponseEmbeddings
# Module-level logger, following the standard `logging.getLogger(__name__)` convention.
LOGGER = logging.getLogger(__name__)
class TestClient(unittest.TestCase):
def test_float_alias(self) -> None:
embeds = EmbedByTypeResponseEmbeddings(float_=[[1.0]])
self.assertEqual(embeds.float_, [[1.0]])
self.assertEqual(embeds.float, [[1.0]]) # type: ignore