Repository: andrewyng/aisuite Branch: main Commit: 695242a836a0 Files: 208 Total size: 1.0 MB Directory structure: gitextract_z7uqp5wm/ ├── .github/ │ └── workflows/ │ ├── black.yml │ └── run_pytest.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── aisuite/ │ ├── __init__.py │ ├── client.py │ ├── design-notes/ │ │ └── asr-parameter-design-motivation.md │ ├── framework/ │ │ ├── __init__.py │ │ ├── asr_params.py │ │ ├── chat_completion_response.py │ │ ├── choice.py │ │ ├── message.py │ │ ├── parameter_mapper.py │ │ └── provider_interface.py │ ├── mcp/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── config.py │ │ ├── schema_converter.py │ │ └── tool_wrapper.py │ ├── provider.py │ ├── providers/ │ │ ├── __init__.py │ │ ├── anthropic_provider.py │ │ ├── aws_provider.py │ │ ├── azure_provider.py │ │ ├── cerebras_provider.py │ │ ├── cohere_provider.py │ │ ├── deepgram_provider.py │ │ ├── deepseek_provider.py │ │ ├── fireworks_provider.py │ │ ├── google_provider.py │ │ ├── groq_provider.py │ │ ├── huggingface_provider.py │ │ ├── inception_provider.py │ │ ├── lmstudio_provider.py │ │ ├── message_converter.py │ │ ├── mistral_provider.py │ │ ├── nebius_provider.py │ │ ├── ollama_provider.py │ │ ├── openai_provider.py │ │ ├── sambanova_provider.py │ │ ├── together_provider.py │ │ ├── watsonx_provider.py │ │ └── xai_provider.py │ └── utils/ │ ├── tools.py │ └── utils.py ├── aisuite-js/ │ ├── README.md │ ├── examples/ │ │ ├── basic-usage.ts │ │ ├── chat-app/ │ │ │ ├── .eslintrc.cjs │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── index.html │ │ │ ├── package.json │ │ │ ├── postcss.config.js │ │ │ ├── src/ │ │ │ │ ├── App.tsx │ │ │ │ ├── components/ │ │ │ │ │ ├── ApiKeyModal.tsx │ │ │ │ │ ├── ChatContainer.tsx │ │ │ │ │ ├── ChatInput.tsx │ │ │ │ │ ├── ChatMessage.tsx │ │ │ │ │ ├── ModelSelector.tsx │ │ │ │ │ └── ProviderSelector.tsx │ │ │ │ ├── config/ │ │ │ │ │ └── llm-config.ts │ │ │ │ ├── index.css │ │ │ │ ├── main.tsx │ │ │ │ ├── 
services/ │ │ │ │ │ └── aisuite-service.ts │ │ │ │ ├── types/ │ │ │ │ │ └── chat.ts │ │ │ │ └── utils/ │ │ │ │ └── cn.ts │ │ │ ├── tailwind.config.js │ │ │ ├── tsconfig.json │ │ │ ├── tsconfig.node.json │ │ │ └── vite.config.ts │ │ ├── deepgram.ts │ │ ├── groq.ts │ │ ├── mistral.ts │ │ ├── openai-asr.ts │ │ ├── streaming.ts │ │ ├── test-suite.ts │ │ └── tool-calling.ts │ ├── jest.config.ts │ ├── package.json │ ├── src/ │ │ ├── asr-providers/ │ │ │ ├── deepgram/ │ │ │ │ ├── adapters.ts │ │ │ │ ├── index.ts │ │ │ │ ├── provider.ts │ │ │ │ └── types.ts │ │ │ └── index.ts │ │ ├── client.ts │ │ ├── core/ │ │ │ ├── base-asr-provider.ts │ │ │ ├── base-provider.ts │ │ │ ├── errors.ts │ │ │ └── model-parser.ts │ │ ├── index.ts │ │ ├── providers/ │ │ │ ├── anthropic/ │ │ │ │ ├── adapters.ts │ │ │ │ ├── index.ts │ │ │ │ ├── provider.ts │ │ │ │ └── types.ts │ │ │ ├── groq/ │ │ │ │ ├── adapters.ts │ │ │ │ ├── index.ts │ │ │ │ ├── provider.ts │ │ │ │ └── types.ts │ │ │ ├── index.ts │ │ │ ├── mistral/ │ │ │ │ ├── adapters.ts │ │ │ │ ├── index.ts │ │ │ │ ├── provider.ts │ │ │ │ └── types.ts │ │ │ └── openai/ │ │ │ ├── adapters.ts │ │ │ ├── index.ts │ │ │ ├── provider.ts │ │ │ └── types.ts │ │ ├── types/ │ │ │ ├── chat.ts │ │ │ ├── common.ts │ │ │ ├── index.ts │ │ │ ├── providers.ts │ │ │ ├── tools.ts │ │ │ └── transcription.ts │ │ └── utils/ │ │ └── streaming.ts │ ├── tests/ │ │ ├── client.test.ts │ │ ├── providers/ │ │ │ ├── anthropic-provider.test.ts │ │ │ ├── deepgram-provider.test.ts │ │ │ ├── groq-provider.test.ts │ │ │ ├── mistral-provider.test.ts │ │ │ ├── openai-provider.test.ts │ │ │ └── openai_asr_provider.test.ts │ │ └── utils/ │ │ └── streaming.test.ts │ └── tsconfig.json ├── examples/ │ ├── AISuiteDemo.ipynb │ ├── DeepseekPost.ipynb │ ├── QnA_with_pdf.ipynb │ ├── agents/ │ │ ├── movie_buff_assistant.ipynb │ │ ├── recipe_chef_assistant.ipynb │ │ ├── snake_game_generator.ipynb │ │ ├── stock_dashboard.html │ │ ├── stock_market_dashboard.html │ │ ├── 
stock_market_mini_tracker.ipynb │ │ ├── stock_market_tracker.ipynb │ │ └── world_weather_dashboard.ipynb │ ├── aisuite_tool_abstraction.ipynb │ ├── asr_example.ipynb │ ├── chat-ui/ │ │ ├── .streamlit/ │ │ │ └── config.toml │ │ ├── README.md │ │ ├── chat.py │ │ └── config.yaml │ ├── client.ipynb │ ├── llm_reasoning.ipynb │ ├── mcp_config_dict_example.py │ ├── mcp_http_example.py │ ├── mcp_tools_example.ipynb │ ├── simple_tool_calling.ipynb │ └── tool_calling_abstraction.ipynb ├── guides/ │ ├── README.md │ ├── anthropic.md │ ├── aws.md │ ├── azure.md │ ├── cerebras.md │ ├── cohere.md │ ├── deepseek.md │ ├── google.md │ ├── groq.md │ ├── huggingface.md │ ├── lmstudio.md │ ├── mistral.md │ ├── nebius.md │ ├── ollama.md │ ├── openai.md │ ├── sambanova.md │ ├── watsonx.md │ └── xai.md ├── pyproject.toml └── tests/ ├── __init__.py ├── client/ │ ├── __init__.py │ ├── test_client.py │ └── test_prerelease.py ├── framework/ │ ├── test_asr_models.py │ └── test_asr_params.py ├── mcp/ │ ├── README.md │ ├── __init__.py │ ├── conftest.py │ ├── test_client.py │ ├── test_e2e.py │ ├── test_http_llm_e2e.py │ ├── test_http_transport.py │ └── test_llm_e2e.py ├── providers/ │ ├── __init__.py │ ├── test_anthropic_converter.py │ ├── test_asr_parameter_passthrough.py │ ├── test_aws_converter.py │ ├── test_azure_provider.py │ ├── test_cerebras_provider.py │ ├── test_cohere_provider.py │ ├── test_deepgram_provider.py │ ├── test_deepseek_provider.py │ ├── test_google_converter.py │ ├── test_google_provider.py │ ├── test_groq_provider.py │ ├── test_huggingface_provider.py │ ├── test_inception_provider.py │ ├── test_lmstudio_provider.py │ ├── test_mistral_provider.py │ ├── test_nebius_provider.py │ ├── test_ollama_provider.py │ ├── test_openai_provider.py │ ├── test_sambanova_provider.py │ └── test_watsonx_provider.py ├── test_provider.py └── utils/ ├── test_mcp_memory_integration.py ├── test_tool_manager.py └── test_tools_mcp_schema.py ================================================ FILE 
CONTENTS ================================================ ================================================ FILE: .github/workflows/black.yml ================================================ name: Lint on: [push, pull_request] jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: psf/black@stable ================================================ FILE: .github/workflows/run_pytest.yml ================================================ name: Lint on: [push, pull_request] jobs: build_and_test: runs-on: ubuntu-latest strategy: matrix: python-version: [ "3.10", "3.11", "3.12" ] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install poetry poetry install --all-extras --with test - name: Test with pytest run: poetry run pytest -m "not integration" ================================================ FILE: .gitignore ================================================ # Python __pycache__/ *.py[cod] *$py.class *.so .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST .env .venv env/ venv/ ENV/ *.whl # Node/TypeScript node_modules/ npm-debug.log* yarn-debug.log* yarn-error.log* .npm .env.local .env.*.local dist/ coverage/ *.tsbuildinfo # IDEs and editors .idea/ .vscode/ *.swp *.swo .DS_Store **/.DS_Store *.sublime-workspace *.sublime-project # Jupyter Notebook .ipynb_checkpoints */.ipynb_checkpoints/* # Testing .coverage htmlcov/ .pytest_cache/ coverage/ .nyc_output/ # Cloud credentials .google-adc # Logs logs *.log # Python version .python-version ================================================ FILE: .pre-commit-config.yaml ================================================ repos: # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: 
https://github.com/psf/black-pre-commit-mirror rev: 24.4.2 hooks: - id: black # It is recommended to specify the latest version of Python # supported by your project here, or alternatively use # pre-commit's default_language_version, see # https://pre-commit.com/#top_level-default_language_version language_version: python3.12 ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to aisuite First off, thanks for taking the time to contribute! All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. > And if you like the project, but just don't have time to contribute, that's fine. There are other easy > ways to support the project and show your appreciation, which we would also be very happy about: > - Star the project > - Tweet about it > - Refer this project in your project's readme > - Mention the project at local meetups and tell your friends/colleagues ## Table of Contents - [I Have a Question](#i-have-a-question) - [I Want To Contribute](#i-want-to-contribute) - [Reporting Bugs](#reporting-bugs) - [Suggesting Enhancements](#suggesting-enhancements) - [Your First Code Contribution](#your-first-code-contribution) - [Improving The Documentation](#improving-the-documentation) - [Styleguides](#styleguides) - [Commit Messages](#commit-messages) ## I Have a Question > If you want to ask a question, we assume that you have read the available > [Documentation](https://github.com/andrewyng/aisuite/blob/main/README.md). 
Before you ask a question, it is best to search for existing [Issues](https://github.com/andrewyng/aisuite/issues) that might help you. If you find a relevant issue that already exists and still need clarification, please add your question to that existing issue. We also recommend reaching out to the community in the aisuite [Discord](https://discord.gg/T6Nvn8ExSb) server. If you then still feel the need to ask a question and need clarification, we recommend the following: - Open an [Issue](https://github.com/andrewyng/aisuite/issues/new). - Provide as much context as you can about what you're running into. - Provide project and platform versions (python, OS, etc.), depending on what seems relevant. We (or someone in the community) will then take care of the issue as soon as possible. ## I Want To Contribute > ### Legal Notice > When contributing to this project, you must agree that you have authored 100% of the content, that > you have the necessary rights to the content and that the content you contribute may be provided > under the project license. ### Reporting Bugs #### Before Submitting a Bug Report A good bug report shouldn't leave others needing to chase you up for more information. Therefore, we ask you to investigate carefully, collect information and describe the issue in detail in your report. Please complete the following steps in advance to help us fix any potential bug as fast as possible. - Make sure that you are using the latest version. - Determine if your bug is really a bug and not an error on your side e.g. using incompatible environment components/versions (Make sure that you have read the [documentation](https://github.com/andrewyng/aisuite/blob/main/README.md). If you are looking for support, you might want to check [this section](#i-have-a-question)). 
- To see if other users have experienced (and potentially already solved) the same issue you are having, check whether a bug report for your bug or error already exists in the [bug tracker](https://github.com/andrewyng/aisuite/issues?q=label%3Abug). - Also make sure to search the internet (including Stack Overflow) to see if users outside of the GitHub community have discussed the issue. - Collect information about the bug: - Stack trace (Traceback) - OS, Platform and Version (Windows, Linux, macOS, x86, ARM) - Version of the interpreter, compiler, SDK, runtime environment, package manager, depending on what seems relevant. - Possibly your input and the output - Can you reliably reproduce the issue? And can you also reproduce it with older versions? #### How Do I Submit a Good Bug Report? > You must never report security related issues, vulnerabilities or bugs including sensitive information to > the issue tracker, or elsewhere in public. Instead, sensitive bugs must be reported privately to the project maintainers by email. We use GitHub issues to track bugs and errors. If you run into an issue with the project: - Open an [Issue](https://github.com/andrewyng/aisuite/issues/new). (Since we can't be sure at this point whether it is a bug or not, we ask you not to talk about a bug yet and not to label the issue.) - Explain the behavior you would expect and the actual behavior. - Please provide as much context as possible and describe the *reproduction steps* that someone else can follow to recreate the issue on their own. This usually includes your code. For good bug reports you should isolate the problem and create a reduced test case. - Provide the information you collected in the previous section. Once it's filed: - The project team will label the issue accordingly. - A team member will try to reproduce the issue with your provided steps. If there are no reproduction steps or no obvious way to reproduce the issue, the team will ask you for those steps and mark the issue as `needs-repro`. 
Bugs with the `needs-repro` tag will not be addressed until they are reproduced. - If the team is able to reproduce the issue, it will be marked `needs-fix`, as well as possibly other tags (such as `critical`), and the issue will be left to be [implemented by someone](#your-first-code-contribution). Please use the issue templates provided. ### Suggesting Enhancements This section guides you through submitting an enhancement suggestion for aisuite, **including completely new features and minor improvements to existing functionality**. Following these guidelines will help maintainers and the community to understand your suggestion and find related suggestions. #### Before Submitting an Enhancement - Make sure that you are using the latest version. - Read the [documentation](https://github.com/andrewyng/aisuite/blob/main/README.md) carefully and find out if the functionality is already covered, maybe by an individual configuration. - Perform a [search](https://github.com/andrewyng/aisuite/issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. - Find out whether your idea fits with the scope and aims of the project. It's up to you to make a strong case to convince the project's developers of the merits of this feature. Keep in mind that we want features that will be useful to the majority of our users and not just a small subset. If you're just targeting a minority of users, consider writing an add-on/plugin library. #### How Do I Submit a Good Enhancement Suggestion? Enhancement suggestions are tracked as [GitHub issues](https://github.com/andrewyng/aisuite/issues). - Use a **clear and descriptive title** for the issue to identify the suggestion. - Provide a **step-by-step description of the suggested enhancement** in as many details as possible. - **Describe the current behavior** and **explain which behavior you expected to see instead** and why. 
At this point you can also tell which alternatives do not work for you. - **Explain why this enhancement would be useful** to most aisuite users. You may also want to point out the other projects that solved it better and which could serve as inspiration. ### Your First Code Contribution #### Pre-requisites You should first [fork](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo) the `aisuite` repository and then clone your forked repository (replace `YOUR_USERNAME` with your GitHub username): ```bash git clone https://github.com/YOUR_USERNAME/aisuite.git ``` Once in the cloned repository directory, make a branch on the forked repository named with your username and a short description of the PR: ```bash git checkout -B YOUR_USERNAME/SHORT_DESCRIPTION ``` Please install the development and test dependencies: ```bash poetry install --with dev,test ``` `aisuite` uses pre-commit to ensure the formatting is consistent: ```bash pre-commit install ``` **Make suggested changes** Afterwards, our suite of formatting tests will run automatically before each `git commit`. You can also run these manually: ```bash pre-commit run --all-files ``` If a formatting test fails, it will fix the modified code in place and abort the `git commit`. After looking over the changes, you can `git add` the modified files and then repeat the previous git commit command. **Note**: a github workflow will check the files with the same formatter and reject the PR if it doesn't pass, so please make sure it passes locally. #### Testing `aisuite` tracks unit tests. Pytest is used to execute said unit tests in `tests/`: ```bash poetry run pytest tests ``` If your code changes implement a new function, please add a corresponding unit test under the `tests/` directory. #### Contributing Workflow We actively welcome your pull requests. 1. Create your new branch from main in your forked repo, with your username and a name describing the work you're completing e.g. user-123/add-feature-x. 2. If you've added code that should be tested, add tests. Ensure all tests pass. 
See the testing section for more information. 3. If you've changed APIs, update the documentation. 4. Make sure your code lints. ### Improving The Documentation We welcome valuable contributions in the form of new documentation or revised documentation that provide further clarity or accuracy. Each function should be clearly documented. Well-documented code is easier to review and understand/extend. ## Styleguides For code documentation, please follow the [Google styleguide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md#38-comments-and-docstrings). ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Andrew Ng Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # aisuite [![PyPI](https://img.shields.io/pypi/v/aisuite)](https://pypi.org/project/aisuite/) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) `aisuite` is a lightweight Python library that provides a **unified API for working with multiple Generative AI providers**. It offers a consistent interface for models from *OpenAI, Anthropic, Google, Hugging Face, AWS, Cohere, Mistral, Ollama*, and others—abstracting away SDK differences, authentication details, and parameter variations. Its design is modeled after OpenAI’s API style, making it instantly familiar and easy to adopt. `aisuite` lets developers build and **run LLM-based or agentic applications across providers** with minimal setup. While it’s not a full-blown agents framework, it includes simple abstractions for creating standalone, lightweight agents. It’s designed for low learning curve — so you can focus on building AI systems, not integrating APIs. --- ## Key Features `aisuite` is designed to eliminate the complexity of working with multiple LLM providers while keeping your code simple and portable. Whether you're building a chatbot, an agentic application, or experimenting with different models, `aisuite` provides the abstractions you need without getting in your way. * **Unified API for multiple model providers** – Write your code once and run it with any supported provider. Switch between OpenAI, Anthropic, Google, and others with a single parameter change. * **Easy agentic app or agent creation** – Build multi-turn agentic applications using a single parameter `max_turns`. No need to manually manage tool execution loops. * **Pass Tool calls easily** – Pass real Python functions instead of JSON specs; aisuite handles schema generation and execution automatically. 
* **MCP tools** – Connect to MCP-based tools without writing boilerplate; aisuite handles connection, schema and execution seamlessly. * **Modular and extensible provider architecture** – Add support for new providers with minimal code. The plugin-style architecture makes extensions straightforward. --- ## Installation You can install just the base `aisuite` package, or install a provider's package along with `aisuite`. Install just the base package without any provider SDKs: ```shell pip install aisuite ``` Install aisuite with a specific provider (e.g., Anthropic): ```shell pip install 'aisuite[anthropic]' ``` Install aisuite with all provider libraries: ```shell pip install 'aisuite[all]' ``` ## Setup To get started, you will need API Keys for the providers you intend to use. You'll need to install the provider-specific library either separately or when installing aisuite. The API Keys can be set as environment variables, or can be passed as config to the aisuite Client constructor. You can use tools like [`python-dotenv`](https://pypi.org/project/python-dotenv/) or [`direnv`](https://direnv.net/) to set the environment variables manually. Please take a look at the `examples` folder to see usage. Here is a short example of using `aisuite` to generate chat completion responses from gpt-4o and claude-3-5-sonnet. Set the API keys. ```shell export OPENAI_API_KEY="your-openai-api-key" export ANTHROPIC_API_KEY="your-anthropic-api-key" ``` Use the python client. ```python import aisuite as ai client = ai.Client() models = ["openai:gpt-4o", "anthropic:claude-3-5-sonnet-20240620"] messages = [ {"role": "system", "content": "Respond in Pirate English."}, {"role": "user", "content": "Tell me a joke."}, ] for model in models: response = client.chat.completions.create( model=model, messages=messages, temperature=0.75 ) print(response.choices[0].message.content) ``` Note that the model name in the `create()` call uses the format `provider:model` (for example, `openai:gpt-4o`). 
`aisuite` will call the appropriate provider with the right parameters based on the provider value. For a list of provider values, you can look at the directory `aisuite/providers/`. Each supported provider is implemented in a file named `{provider}_provider.py` in that directory (for example, `openai_provider.py`). We welcome providers to add support to this library by adding an implementation file in this directory. Please see the section below for how to contribute. For more examples, check out the `examples` directory where you will find several notebooks that you can run to experiment with the interface. --- ## Chat Completions The chat API provides a high-level abstraction for model interactions. It supports all core parameters (`temperature`, `max_tokens`, `tools`, etc.) in a provider-agnostic way. ```python response = client.chat.completions.create( model="google:gemini-pro", messages=[{"role": "user", "content": "Summarize this paragraph."}], ) print(response.choices[0].message.content) ``` `aisuite` standardizes request and response structures so you can focus on logic rather than SDK differences. --- ## Tool Calling & Agentic apps `aisuite` provides a simple abstraction for tool/function calling that works across supported providers. This is in addition to the regular abstraction of passing a JSON spec of the tool to the `tools` parameter. The tool calling abstraction makes it easy to use tools with different LLMs without changing your code. There are two ways to use tools with `aisuite`: ### 1. Manual Tool Handling This is the default behavior when `max_turns` is not specified. In this mode, you have full control over the tool execution flow. You pass tools using the standard OpenAI JSON schema format, and `aisuite` returns the LLM's tool call requests in the response. You're then responsible for executing the tools, processing results, and sending them back to the model in subsequent requests. 
This approach is useful when you need: - Fine-grained control over tool execution logic - Custom error handling or validation before executing tools - The ability to selectively execute or skip certain tool calls - Integration with existing tool execution pipelines You can pass tools in the OpenAI tool format: ```python def will_it_rain(location: str, time_of_day: str): """Check if it will rain in a location at a given time today. Args: location (str): Name of the city time_of_day (str): Time of the day in HH:MM format. """ return "YES" tools = [{ "type": "function", "function": { "name": "will_it_rain", "description": "Check if it will rain in a location at a given time today", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "Name of the city" }, "time_of_day": { "type": "string", "description": "Time of the day in HH:MM format." } }, "required": ["location", "time_of_day"] } } }] response = client.chat.completions.create( model="openai:gpt-4o", messages=messages, tools=tools ) ``` ### 2. Automatic Tool Execution When `max_turns` is specified, you can pass a list of callable Python functions as the `tools` parameter. `aisuite` will automatically handle the tool calling flow: ```python def will_it_rain(location: str, time_of_day: str): """Check if it will rain in a location at a given time today. Args: location (str): Name of the city time_of_day (str): Time of the day in HH:MM format. """ return "YES" client = ai.Client() messages = [{ "role": "user", "content": "I live in San Francisco. Can you check for weather " "and plan an outdoor picnic for me at 2pm?" }] # Automatic tool execution with max_turns response = client.chat.completions.create( model="openai:gpt-4o", messages=messages, tools=[will_it_rain], max_turns=2 # Maximum number of back-and-forth tool calls ) print(response.choices[0].message.content) ``` When `max_turns` is specified, `aisuite` will: 1. Send your message to the LLM 2. 
Execute any tool calls the LLM requests 3. Send the tool results back to the LLM 4. Repeat until the conversation is complete or max_turns is reached In addition to `response.choices[0].message`, there is an additional field `response.choices[0].intermediate_messages` which contains the list of all messages including tool interactions used. This can be used to continue the conversation with the model. For more detailed examples of tool calling, check out the `examples/tool_calling_abstraction.ipynb` notebook. ### Model Context Protocol (MCP) Integration `aisuite` natively supports **MCP**, a standard protocol that allows LLMs to securely call external tools and access data. You can connect to MCP servers—such as a filesystem or database—and expose their tools directly to your model. Read more about MCP here - https://modelcontextprotocol.io/docs/getting-started/intro Install aisuite with MCP support: ```shell pip install 'aisuite[mcp]' ``` You'll also need an MCP server. For example, to use the filesystem server: ```shell npm install -g @modelcontextprotocol/server-filesystem ``` There are two ways to use MCP tools with aisuite: #### Option 1: Config Dict Format (Recommended for Simple Use Cases) ```python import aisuite as ai client = ai.Client() response = client.chat.completions.create( model="openai:gpt-4o", messages=[{"role": "user", "content": "List the files in the current directory"}], tools=[{ "type": "mcp", "name": "filesystem", "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/directory"] }], max_turns=3 ) print(response.choices[0].message.content) ``` #### Option 2: Explicit MCPClient (Recommended for Advanced Use Cases) ```python import aisuite as ai from aisuite.mcp import MCPClient # Create MCP client once, reuse across requests mcp = MCPClient( command="npx", args=["-y", "@modelcontextprotocol/server-filesystem", "/path/to/directory"] ) # Use with aisuite client = ai.Client() response = 
client.chat.completions.create( model="openai:gpt-4o", messages=[{"role": "user", "content": "List the files"}], tools=mcp.get_callable_tools(), max_turns=3 ) print(response.choices[0].message.content) mcp.close() # Clean up ``` For detailed usage (security filters, tool prefixing, and `MCPClient` management), see [docs/mcp-tools.md](docs/mcp-tools.md). For detailed examples, see `examples/mcp_tools_example.ipynb`. --- ## Extending aisuite: Adding a Provider New providers can be added by implementing a lightweight adapter. The system uses a naming convention for discovery: | Element | Convention | | --------------- | ---------------------------------- | | **Module file** | `{provider}_provider.py` | | **Class name** | `{Provider}Provider` (provider name capitalized) | Example: ```python # providers/openai_provider.py class OpenaiProvider(BaseProvider): ... ``` This convention ensures consistency and enables automatic loading of new integrations. --- ## Contributing Contributions are welcome. Please review the [Contributing Guide](https://github.com/andrewyng/aisuite/blob/main/CONTRIBUTING.md) and join our [Discord](https://discord.gg/T6Nvn8ExSb) for discussions. --- ## License Released under the **MIT License** — free for commercial and non-commercial use. 
--- ================================================ FILE: aisuite/__init__.py ================================================ from .client import Client from .framework.message import Message from .utils.tools import Tools ================================================ FILE: aisuite/client.py ================================================ from .provider import ProviderFactory import os from .utils.tools import Tools from typing import Union, BinaryIO, Optional, Any, Literal from contextlib import ExitStack from .framework.message import ( TranscriptionResponse, ) from .framework.asr_params import ParamValidator # Import MCP utilities for config dict support try: from .mcp.config import is_mcp_config from .mcp.client import MCPClient MCP_AVAILABLE = True except ImportError: MCP_AVAILABLE = False class Client: def __init__( self, provider_configs: dict = {}, extra_param_mode: Literal["strict", "warn", "permissive"] = "warn", ): """ Initialize the client with provider configurations. Use the ProviderFactory to create provider instances. Args: provider_configs (dict): A dictionary containing provider configurations. Each key should be a provider string (e.g., "google" or "aws-bedrock"), and the value should be a dictionary of configuration options for that provider. For example: { "openai": {"api_key": "your_openai_api_key"}, "aws-bedrock": { "aws_access_key": "your_aws_access_key", "aws_secret_key": "your_aws_secret_key", "aws_region": "us-west-2" } } extra_param_mode (str): How to handle unknown ASR parameters. 
- "strict": Raise ValueError on unknown params (production) - "warn": Log warning on unknown params (default, development) - "permissive": Allow all params without validation (testing) """ self.providers = {} self.provider_configs = provider_configs self.extra_param_mode = extra_param_mode self.param_validator = ParamValidator(extra_param_mode) self._chat = None self._audio = None def _initialize_providers(self): """Helper method to initialize or update providers.""" for provider_key, config in self.provider_configs.items(): provider_key = self._validate_provider_key(provider_key) self.providers[provider_key] = ProviderFactory.create_provider( provider_key, config ) def _validate_provider_key(self, provider_key): """ Validate if the provider key corresponds to a supported provider. """ supported_providers = ProviderFactory.get_supported_providers() if provider_key not in supported_providers: raise ValueError( f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. " "Make sure the model string is formatted correctly as 'provider:model'." ) return provider_key def configure(self, provider_configs: Optional[dict] = None): """ Configure the client with provider configurations. 
""" if provider_configs is None: return self.provider_configs.update(provider_configs) # Providers will be lazily initialized when needed @property def chat(self): """Return the chat API interface.""" if not self._chat: self._chat = Chat(self) return self._chat @property def audio(self): """Return the audio API interface.""" if not self._audio: self._audio = Audio(self) return self._audio class Chat: def __init__(self, client: "Client"): self.client = client self._completions = Completions(self.client) @property def completions(self): """Return the completions interface.""" return self._completions class Completions: def __init__(self, client: "Client"): self.client = client def _process_mcp_configs(self, tools: list) -> tuple[list, list]: """ Process tools list and convert MCP config dicts to callable tools. This method: 1. Detects MCP config dicts ({"type": "mcp", ...}) 2. Creates MCPClient instances from configs 3. Extracts callable tools with filtering and prefixing 4. Mixes MCP tools with regular callable tools 5. Returns both processed tools and MCP clients for cleanup Args: tools: List of tools (mix of callables and MCP configs) Returns: Tuple of (processed_tools, mcp_clients): - processed_tools: List of callable tools only - mcp_clients: List of MCPClient instances to be cleaned up Example: >>> tools = [ ... my_function, ... {"type": "mcp", "name": "fs", "command": "npx", "args": [...]}, ... another_function ... ] >>> callable_tools, mcp_clients = self._process_mcp_configs(tools) >>> # Returns: ([my_function, fs_tool1, fs_tool2, ..., another_function], [mcp_client]) """ if not MCP_AVAILABLE: # If MCP not installed, check if user is trying to use it if any(is_mcp_config(tool) for tool in tools if isinstance(tool, dict)): raise ImportError( "MCP tools require the 'mcp' package. 
" "Install it with: pip install 'aisuite[mcp]' or pip install mcp" ) return tools, [] processed_tools = [] mcp_clients = [] for tool in tools: if isinstance(tool, dict) and is_mcp_config(tool): # It's an MCP config dict - convert to callable tools try: mcp_client = MCPClient.from_config(tool) mcp_clients.append(mcp_client) # Get tools with config settings mcp_tools = mcp_client.get_callable_tools( allowed_tools=tool.get("allowed_tools"), use_tool_prefix=tool.get("use_tool_prefix", False), ) processed_tools.extend(mcp_tools) except Exception as e: raise ValueError( f"Failed to create MCP client from config: {e}\n" f"Config: {tool}" ) else: # Regular callable tool - pass through processed_tools.append(tool) return processed_tools, mcp_clients def _extract_thinking_content(self, response): """ Extract content between tags if present and store it in reasoning_content. Args: response: The response object from the provider Returns: Modified response object """ if hasattr(response, "choices") and response.choices: message = response.choices[0].message if hasattr(message, "content") and message.content: content = message.content.strip() if content.startswith("") and "" in content: # Extract content between think tags start_idx = len("") end_idx = content.find("") thinking_content = content[start_idx:end_idx].strip() # Store the thinking content message.reasoning_content = thinking_content # Remove the think tags from the original content message.content = content[end_idx + len("") :].strip() return response def _tool_runner( self, provider, model_name: str, messages: list, tools: Any, max_turns: int, **kwargs, ): """ Handle tool execution loop for max_turns iterations. 
Args: provider: The provider instance to use for completions model_name: Name of the model to use messages: List of conversation messages tools: Tools instance or list of callable tools max_turns: Maximum number of tool execution turns **kwargs: Additional arguments to pass to the provider Returns: The final response from the model with intermediate responses and messages """ # Handle tools validation and conversion if isinstance(tools, Tools): tools_instance = tools kwargs["tools"] = tools_instance.tools() else: # Check if passed tools are callable if not all(callable(tool) for tool in tools): raise ValueError("One or more tools is not callable") tools_instance = Tools(tools) kwargs["tools"] = tools_instance.tools() turns = 0 intermediate_responses = [] # Store intermediate responses intermediate_messages = [] # Store all messages including tool interactions while turns < max_turns: # Make the API call response = provider.chat_completions_create(model_name, messages, **kwargs) response = self._extract_thinking_content(response) # Store intermediate response intermediate_responses.append(response) # Check if there are tool calls in the response tool_calls = ( getattr(response.choices[0].message, "tool_calls", None) if hasattr(response, "choices") else None ) # Store the model's message intermediate_messages.append(response.choices[0].message) if not tool_calls: # Set the intermediate data in the final response response.intermediate_responses = intermediate_responses[ :-1 ] # Exclude final response response.choices[0].intermediate_messages = intermediate_messages return response # Execute tools and get results results, tool_messages = tools_instance.execute_tool(tool_calls) # Add tool messages to intermediate messages intermediate_messages.extend(tool_messages) # Add the assistant's response and tool results to messages messages.extend([response.choices[0].message, *tool_messages]) turns += 1 # Set the intermediate data in the final response 
response.intermediate_responses = intermediate_responses[ :-1 ] # Exclude final response response.choices[0].intermediate_messages = intermediate_messages return response def create(self, model: str, messages: list, **kwargs): """ Create chat completion based on the model, messages, and any extra arguments. Supports automatic tool execution when max_turns is specified. """ # Check that correct format is used if ":" not in model: raise ValueError( f"Invalid model format. Expected 'provider:model', got '{model}'" ) # Extract the provider key from the model identifier, e.g., "google:gemini-xx" provider_key, model_name = model.split(":", 1) # Validate if the provider is supported supported_providers = ProviderFactory.get_supported_providers() if provider_key not in supported_providers: raise ValueError( f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. " "Make sure the model string is formatted correctly as 'provider:model'." ) # Initialize provider if not already initialized # TODO: Add thread-safe provider initialization with lock to prevent race conditions # when multiple threads try to initialize the same provider simultaneously. 
if provider_key not in self.client.providers: config = self.client.provider_configs.get(provider_key, {}) self.client.providers[provider_key] = ProviderFactory.create_provider( provider_key, config ) provider = self.client.providers.get(provider_key) if not provider: raise ValueError(f"Could not load provider for '{provider_key}'.") # Extract tool-related parameters max_turns = kwargs.pop("max_turns", None) tools = kwargs.pop("tools", None) # Use ExitStack to manage MCP client cleanup automatically with ExitStack() as stack: # Convert MCP config dicts to callable tools and get MCP clients mcp_clients = [] if tools is not None: tools, mcp_clients = self._process_mcp_configs(tools) # Register all MCP clients for automatic cleanup for mcp_client in mcp_clients: stack.enter_context(mcp_client) # Check environment variable before allowing multi-turn tool execution if max_turns is not None and tools is not None: return self._tool_runner( provider, model_name, messages.copy(), tools, max_turns, **kwargs, ) # Default behavior without tool execution # Delegate the chat completion to the correct provider's implementation response = provider.chat_completions_create(model_name, messages, **kwargs) return self._extract_thinking_content(response) class Audio: """Audio API interface.""" def __init__(self, client: "Client"): self.client = client self._transcriptions = Transcriptions(self.client) @property def transcriptions(self): """Return the transcriptions interface.""" return self._transcriptions class Transcriptions: """Transcriptions API interface.""" def __init__(self, client: "Client"): self.client = client def create( self, *, model: str, file: Union[str, BinaryIO], **kwargs, ) -> TranscriptionResponse: """ Create audio transcription with parameter validation. 
This method uses a pass-through approach with validation: - Common parameters (OpenAI-style) are auto-mapped to provider equivalents - Provider-specific parameters are passed through directly - Unknown parameters are handled based on extra_param_mode Args: model: Provider and model in format 'provider:model' (e.g., 'openai:whisper-1') file: Audio file to transcribe (file path or file-like object) **kwargs: Transcription parameters (provider-specific or common) Common parameters (portable across providers): - language: Language code (e.g., "en") - prompt: Context for the transcription - temperature: Sampling temperature (0-1, OpenAI only) Provider-specific parameters are passed through directly. See provider documentation for valid parameters. Returns: TranscriptionResponse: Unified response (batch or streaming) Raises: ValueError: If model format invalid, provider not supported, or unknown params in strict mode Examples: # Portable code (OpenAI-style params) >>> result = client.audio.transcriptions.create( ... model="openai:whisper-1", ... file="audio.mp3", ... language="en" ... ) # Provider-specific features >>> result = client.audio.transcriptions.create( ... model="deepgram:nova-2", ... file="audio.mp3", ... language="en", # Common param ... punctuate=True, # Deepgram-specific ... diarize=True # Deepgram-specific ... ) """ # Validate model format if ":" not in model: raise ValueError( f"Invalid model format. Expected 'provider:model', got '{model}'" ) # Extract provider and model name provider_key, model_name = model.split(":", 1) # Validate provider is supported supported_providers = ProviderFactory.get_supported_providers() if provider_key not in supported_providers: raise ValueError( f"Invalid provider key '{provider_key}'. 
" f"Supported providers: {supported_providers}" ) # Validate and map parameters validated_params = self.client.param_validator.validate_and_map( provider_key, kwargs ) # Initialize provider if not already initialized if provider_key not in self.client.providers: config = self.client.provider_configs.get(provider_key, {}) try: self.client.providers[provider_key] = ProviderFactory.create_provider( provider_key, config ) except ImportError as e: raise ValueError(f"Provider '{provider_key}' is not available: {e}") provider = self.client.providers.get(provider_key) if not provider: raise ValueError(f"Could not load provider for '{provider_key}'.") # Check if provider supports audio transcription if not hasattr(provider, "audio") or provider.audio is None: raise ValueError( f"Provider '{provider_key}' does not support audio transcription." ) # Determine if streaming is requested should_stream = validated_params.get("stream", False) # Delegate to provider implementation try: if should_stream: # Check if provider supports output streaming if hasattr(provider.audio, "transcriptions") and hasattr( provider.audio.transcriptions, "create_stream_output" ): return provider.audio.transcriptions.create_stream_output( model_name, file, **validated_params ) else: raise ValueError( f"Provider '{provider_key}' does not support streaming transcription." ) else: # Non-streaming (batch) transcription if hasattr(provider.audio, "transcriptions") and hasattr( provider.audio.transcriptions, "create" ): return provider.audio.transcriptions.create( model_name, file, **validated_params ) else: raise ValueError( f"Provider '{provider_key}' does not support audio transcription." ) except NotImplementedError: raise ValueError( f"Provider '{provider_key}' does not support audio transcription." 
) ================================================ FILE: aisuite/design-notes/asr-parameter-design-motivation.md ================================================ # ASR - API Parameter Design Philosophy ## Design Goal: Portable Code with Provider Flexibility The ASR parameter system is designed around a core principle: **developers should write portable code that works across providers, while retaining the ability to use provider-specific features when needed**. This document explains the rationale behind our parameter classification and validation approach. --- ## Mandatory Parameters and Common Mappings ### The Foundation: Minimal Requirements Every transcription needs just two things: - **`model`**: Which model/provider to use - **`file`**: What audio to transcribe By keeping mandatory parameters minimal, we maximize compatibility and reduce the barrier to getting started. ### Common Parameters: Write Once, Run Anywhere Beyond the basics, there are concepts that exist across providers but use different names or formats. 
We handle three common parameters that auto-map to each provider's native API: **Example: Same code, different providers** ```python # Works with OpenAI result = client.audio.transcriptions.create( model="openai:whisper-1", file="meeting.mp3", language="en", prompt="discussion about API design" ) # Exact same code works with Deepgram result = client.audio.transcriptions.create( model="deepgram:nova-2", file="meeting.mp3", language="en", prompt="discussion about API design" ) ``` Behind the scenes: - **`language`** passes through as `language` for both OpenAI and Deepgram, but expands to `language_code: "en-US"` for Google - **`prompt`** passes as `prompt` to OpenAI, transforms to `keywords: ["discussion", "about", "API", "design"]` for Deepgram, and becomes `speech_contexts: [{"phrases": ["discussion about API design"]}]` for Google - **`temperature`** passes through to OpenAI (which supports it) and is silently ignored by Deepgram and Google (which don't) **Why auto-mapping?** Developers shouldn't need to remember that Google uses `language_code` while others use `language`, or that Deepgram expects a list of keywords. The framework handles these provider quirks transparently, letting you write portable code. --- ## Provider-Specific Features: Pass-Through for Power Users Each provider has unique features that give them competitive advantages. 
We don't limit you to the "lowest common denominator" - if you need provider-specific functionality, it's available: **Deepgram's advanced features:** ```python result = client.audio.transcriptions.create( model="deepgram:nova-2", file="meeting.mp3", language="en", punctuate=True, # Deepgram-specific diarize=True, # Deepgram-specific sentiment=True, # Deepgram-specific smart_format=True # Deepgram-specific ) ``` **Google's speech contexts:** ```python result = client.audio.transcriptions.create( model="google:latest_long", file="meeting.mp3", language_code="en-US", enable_automatic_punctuation=True, # Google-specific max_alternatives=3, # Google-specific speech_contexts=[{"phrases": ["API", "SDK", "REST"]}] # Google-specific ) ``` These provider-specific parameters pass through directly to the provider's SDK. The framework validates them based on your configured mode (see next section), but doesn't block access to unique features. --- ## Progressive Validation: Safety When You Need It The validation system supports three modes to match different development stages: ### Development Mode: `"warn"` (Default) ```python client = Client(extra_param_mode="warn") ``` Unknown parameters trigger warnings but continue execution. Perfect for exploration and prototyping. You see *"OpenAI doesn't support 'punctuate'"* but your code keeps running. ### Strict Mode: `"strict"` ```python client = Client(extra_param_mode="strict") ``` Unknown parameters raise errors immediately. Use in production to catch typos, configuration mistakes, or provider API changes early. Ensures no silent failures. ### Permissive Mode: `"permissive"` ```python client = Client(extra_param_mode="permissive") ``` All parameters pass through without validation. Use for beta features, experimental parameters, or when providers add new capabilities faster than framework updates. **Progressive workflow:** 1. **Develop** with `warn` - explore freely, see warnings 2. 
**Refactor** - fix warnings to make code portable 3. **Deploy** with `strict` - ensure production safety --- ## Developer Experience Benefits ### 1. Write Portable Code Naturally The same parameter names work across providers. Switch from OpenAI to Deepgram by changing one word: the model identifier. ### 2. Progressive Enhancement Start with portable common parameters. Add provider-specific features only where you need them. Your core logic remains portable even when using advanced features for specific providers. ### 3. Zero Framework Lock-in Parameter names come directly from provider APIs, not framework abstractions. If you need to remove the framework, you already know the native API - the names are identical. ### 4. Validation That Adapts to You Choose your safety level based on context. Strict for production, warn for development, permissive for bleeding-edge features. The framework supports your workflow rather than constraining it. ### 5. No Documentation Friction Copy parameters from provider docs directly. No need to learn our abstraction layer or figure out mappings - we handle the common cases, you use native names for everything else. --- ## Alternative Design Considered We considered creating a unified options object (`TranscriptionOptions`) that explicitly defines all parameters with framework-specific names. We chose pass-through instead because: 1. **Provider APIs evolve faster than frameworks** - New parameters appear frequently. Pass-through lets developers use them immediately (in permissive mode) without waiting for framework updates. 2. **Provider features don't map cleanly** - Deepgram's sentiment analysis, Google's complex speech contexts, OpenAI's timestamp granularities - each is unique. A unified object means either losing functionality or creating complex provider-specific abstractions. 3. **Direct API access reduces friction** - Developers already know their provider's API from official docs. 
They can use parameter names directly rather than learning another abstraction layer. The pass-through approach with progressive validation provides the best of both worlds: portability for common cases, power for advanced features, and safety when you need it. --- ## Design Principles Summary - **Mandatory Minimal**: Only `model` and `file` required - **Common Auto-Mapped**: Frequent cross-provider concepts map transparently - **Provider-Specific Pass-Through**: Unique features remain accessible - **Progressive Validation**: Three modes for different development stages - **Zero Abstraction Tax**: Use provider APIs directly with optional safety nets This design prioritizes developer experience through portability without sacrificing power, validation without blocking experimentation, and simplicity without limiting functionality. ================================================ FILE: aisuite/framework/__init__.py ================================================ from .provider_interface import ProviderInterface from .chat_completion_response import ChatCompletionResponse from .message import Message ================================================ FILE: aisuite/framework/asr_params.py ================================================ """ ASR parameter registry and validation. This module provides a unified parameter validation system for audio transcription across different providers. 
It supports:
- Common parameters (OpenAI-style) that are auto-mapped to provider equivalents
- Provider-specific parameters that are passed through directly
- Three validation modes: strict, warn, and permissive
"""

from typing import Dict, Set, Any, Optional, Literal
import logging

logger = logging.getLogger(__name__)

# Common parameters that get auto-mapped across providers
# These follow OpenAI's API conventions for maximum portability
# Layout: common name -> {provider key -> provider-native name}.
# A value of None marks the parameter as unsupported for that provider.
COMMON_PARAMS: Dict[str, Dict[str, Optional[str]]] = {
    "language": {
        "openai": "language",
        "deepgram": "language",
        "google": "language_code",
        "huggingface": None,  # Not supported by Inference API
    },
    "prompt": {
        "openai": "prompt",
        "deepgram": "keywords",
        "google": "speech_contexts",
        "huggingface": None,  # Not supported
    },
    "temperature": {
        "openai": "temperature",
        "deepgram": None,  # Not supported
        "google": None,  # Not supported
        "huggingface": "temperature",  # Supported as generation param
    },
}

# Valid provider-specific parameters
# Each provider has its own set of supported parameters
# Layout: provider key -> set of parameter names accepted for that provider.
PROVIDER_PARAMS: Dict[str, Set[str]] = {
    "openai": {
        # Basic parameters
        "language",
        "prompt",
        "temperature",
        # Output format
        "response_format",  # "json" | "text" | "srt" | "verbose_json" | "vtt"
        "timestamp_granularities",  # ["word"] | ["segment"] | ["word", "segment"]
        # Streaming
        "stream",  # Boolean
    },
    "deepgram": {
        # Basic parameters
        "language",
        "model",
        # Text enhancement
        "punctuate",  # Auto-add punctuation
        "diarize",  # Speaker diarization
        "utterances",  # Sentence-level timestamps
        "paragraphs",  # Paragraph segmentation
        "smart_format",  # Format numbers, dates, etc.
        "profanity_filter",  # Filter profanity
        # Advanced features
        "search",  # Search for keywords: ["keyword1", "keyword2"]
        "replace",  # Replace words: {"um": "", "uh": ""}
        "keywords",  # Boost keywords: ["important", "technical"]
        "numerals",  # Format numerals
        "measurements",  # Format measurements
        # AI features
        "sentiment",  # Sentiment analysis
        "topics",  # Topic detection
        "intents",  # Intent recognition
        "summarize",  # Auto-summarization
        # Audio format
        "encoding",  # "linear16" | "mp3" | "flac"
        "sample_rate",  # Integer (Hz)
        "channels",  # Integer
        # Quality and alternatives
        "confidence",  # Include confidence scores
        "alternatives",  # Number of alternative transcripts
        # Streaming
        "interim_results",  # Get interim results while streaming
    },
    "google": {
        # Basic parameters
        "language_code",  # BCP-47 code like "en-US"
        "model",  # "latest_long" | "latest_short" | "default"
        # Audio format
        "encoding",  # "LINEAR16" | "FLAC" | "MP3"
        "sample_rate_hertz",  # Integer
        "audio_channel_count",  # Integer
        # Text enhancement
        "enable_automatic_punctuation",  # Boolean
        "profanity_filter",  # Boolean
        "enable_spoken_punctuation",  # Boolean
        "enable_spoken_emojis",  # Boolean
        # Speaker features
        "enable_speaker_diarization",  # Boolean
        "diarization_speaker_count",  # Integer (max speakers)
        "min_speaker_count",  # Integer
        # Metadata
        "enable_word_time_offsets",  # Word-level timestamps
        "enable_word_confidence",  # Word-level confidence
        "max_alternatives",  # Number of alternatives
        # Context
        "speech_contexts",  # [{"phrases": [...], "boost": float}]
        "boost",  # Float (phraseHint boost)
        # Streaming
        "interim_results",  # Boolean
        "single_utterance",  # Boolean (stop after one utterance)
    },
    "huggingface": {
        # Basic parameters
        "model",  # Model ID on Hugging Face Hub
        "temperature",  # Generation temperature
        # API options
        "return_timestamps",  # Boolean or "word" or "char"
        "use_cache",  # Boolean: use cached inference
        "wait_for_model",  # Boolean: wait if model is loading
        # Generation parameters
        "top_k",  # Integer: top-k sampling
        "top_p",  # Float: nucleus sampling
        "max_length",  # Integer: maximum output length
        "do_sample",  # Boolean: enable sampling
    },
}

# Language code expansion for Google (2-letter to locale codes)
# Each 2-letter code maps to a single default locale (e.g. "pt" -> "pt-BR").
GOOGLE_LANGUAGE_MAP: Dict[str, str] = {
    "en": "en-US",
    "es": "es-ES",
    "fr": "fr-FR",
    "de": "de-DE",
    "it": "it-IT",
    "pt": "pt-BR",
    "ja": "ja-JP",
    "ko": "ko-KR",
    "zh": "zh-CN",
    "ar": "ar-SA",
    "hi": "hi-IN",
    "ru": "ru-RU",
    "nl": "nl-NL",
    "pl": "pl-PL",
    "sv": "sv-SE",
    "da": "da-DK",
    "no": "nb-NO",
    "fi": "fi-FI",
    "tr": "tr-TR",
    "th": "th-TH",
    "vi": "vi-VN",
}


class ParamValidator:
    """
    Validates and maps ASR parameters for different providers.

    This class handles three types of parameters:
    1. Common parameters (OpenAI-style) - auto-mapped to provider equivalents
    2. Provider-specific parameters - passed through with validation
    3. Unknown parameters - handled based on extra_param_mode
    """

    def __init__(self, extra_param_mode: Literal["strict", "warn", "permissive"]):
        """
        Initialize the parameter validator.

        Args:
            extra_param_mode: How to handle unknown parameters
                - "strict": Raise ValueError on unknown params
                - "warn": Log warning on unknown params (default)
                - "permissive": Allow all params without validation
        """
        # Stored as given; the mode string is not validated here.
        self.extra_param_mode = extra_param_mode

    def validate_and_map(
        self, provider_key: str, params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Validate and map parameters for the given provider.

        This method:
        1. Maps common parameters to provider-specific equivalents
        2. Validates provider-specific parameters
        3.
Handles unknown parameters based on extra_param_mode Args: provider_key: Provider identifier (e.g., "openai", "deepgram") params: Raw parameters from user Returns: Validated and mapped parameters ready for provider API Raises: ValueError: If extra_param_mode="strict" and unknown params found """ result = {} unknown_params = [] provider_params = PROVIDER_PARAMS.get(provider_key, set()) for key, value in params.items(): # Check if it's a common param that needs mapping if key in COMMON_PARAMS: mapped_key = COMMON_PARAMS[key].get(provider_key) # Provider doesn't support this common param if mapped_key is None: logger.debug( f"Parameter '{key}' not supported by {provider_key}, ignoring" ) continue # Transform value if needed (e.g., "en" -> "en-US" for Google) mapped_value = self._transform_value(provider_key, key, value) result[mapped_key] = mapped_value # Check if it's a valid provider-specific param elif key in provider_params: result[key] = value # Unknown parameter else: unknown_params.append(key) # Handle unknown parameters based on mode if unknown_params: self._handle_unknown(provider_key, unknown_params) # In permissive mode, still pass them through if self.extra_param_mode == "permissive": for key in unknown_params: result[key] = params[key] return result def _transform_value(self, provider_key: str, param_key: str, value: Any) -> Any: """ Transform parameter values during mapping. 
This handles provider-specific transformations like: - Google: Expanding "en" to "en-US" - Google: Wrapping prompt in speech_contexts structure - Deepgram: Converting prompt string to keywords list Args: provider_key: Provider identifier param_key: Parameter name (from COMMON_PARAMS) value: Parameter value to transform Returns: Transformed parameter value """ # Google: Expand 2-letter language codes to locale codes if provider_key == "google" and param_key == "language": if isinstance(value, str) and len(value) == 2: return GOOGLE_LANGUAGE_MAP.get(value, f"{value}-US") # Google: Wrap prompt in speech_contexts structure if provider_key == "google" and param_key == "prompt": return [{"phrases": [value]}] # Deepgram: Split prompt into keywords list if provider_key == "deepgram" and param_key == "prompt": if isinstance(value, str): return value.split() return value return value def _handle_unknown(self, provider_key: str, unknown_params: list): """ Handle unknown parameters based on extra_param_mode. Args: provider_key: Provider identifier unknown_params: List of unknown parameter names Raises: ValueError: If extra_param_mode="strict" """ msg = ( f"Unknown parameters for {provider_key}: {unknown_params}. " f"See {provider_key} documentation for valid parameters." 
) if self.extra_param_mode == "strict": raise ValueError(msg) elif self.extra_param_mode == "warn": import warnings warnings.warn(msg, UserWarning) # permissive mode: do nothing ================================================ FILE: aisuite/framework/chat_completion_response.py ================================================ """Defines the ChatCompletionResponse class.""" from typing import Optional from aisuite.framework.choice import Choice from aisuite.framework.message import CompletionUsage # pylint: disable=too-few-public-methods class ChatCompletionResponse: """Used to conform to the response model of OpenAI.""" def __init__(self): """Initializes the ChatCompletionResponse.""" self.choices = [Choice()] # Adjust the range as needed for more choices self.usage: Optional[CompletionUsage] = None ================================================ FILE: aisuite/framework/choice.py ================================================ from aisuite.framework.message import Message from typing import Literal, Optional, List class Choice: def __init__(self): self.finish_reason: Optional[Literal["stop", "tool_calls"]] = None self.message = Message( content=None, tool_calls=None, role="assistant", refusal=None, reasoning_content=None, ) self.intermediate_messages: List[Message] = [] ================================================ FILE: aisuite/framework/message.py ================================================ """ Interface to hold contents of api responses when they do not confirm to the OpenAI style response. 
""" from typing import Literal, Optional, List, AsyncGenerator, Union, Dict, Any from pydantic import BaseModel from dataclasses import dataclass, field class Function(BaseModel): """Represents a function call.""" arguments: str name: str class ChatCompletionMessageToolCall(BaseModel): """Represents a tool call in a chat completion message.""" id: str function: Function type: Literal["function"] class Message(BaseModel): """Represents a message in a chat completion.""" content: Optional[str] = None reasoning_content: Optional[str] = None tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None role: Optional[Literal["user", "assistant", "system", "tool"]] = None refusal: Optional[str] = None class CompletionTokensDetails(BaseModel): """Details about the tokens used in a completion.""" accepted_prediction_tokens: Optional[int] = None """ When using Predicted Outputs, the number of tokens in the prediction that appeared in the completion. """ audio_tokens: Optional[int] = None """Audio input tokens generated by the model.""" reasoning_tokens: Optional[int] = None """Tokens generated by the model for reasoning.""" rejected_prediction_tokens: Optional[int] = None """ When using Predicted Outputs, the number of tokens in the prediction that did not appear in the completion. However, like reasoning tokens, these tokens are still counted in the total completion tokens for purposes of billing, output, and context window limits. 
""" class PromptTokensDetails(BaseModel): """Details about the tokens used in a prompt.""" text_tokens: Optional[int] = None """Tokens generated by the model for text.""" audio_tokens: Optional[int] = None """Audio input tokens present in the prompt.""" cached_tokens: Optional[int] = None """Cached tokens present in the prompt.""" class CompletionUsage(BaseModel): """Represents the token usage for a completion.""" completion_tokens: Optional[int] = None """Number of tokens in the generated completion.""" prompt_tokens: Optional[int] = None """Number of tokens in the prompt.""" total_tokens: Optional[int] = None """Total number of tokens used in the request (prompt + completion).""" completion_tokens_details: Optional[CompletionTokensDetails] = None """Breakdown of tokens used in a completion.""" prompt_tokens_details: Optional[PromptTokensDetails] = None """Breakdown of tokens used in the prompt.""" class Word(BaseModel): """Represents a single word with timing information.""" word: str start: float end: float confidence: Optional[float] = None # Common across Deepgram, Azure, AWS speaker: Optional[int] = None # Speaker diarization (Deepgram, Azure, AWS) speaker_confidence: Optional[float] = None # Speaker identification confidence punctuated_word: Optional[str] = None # Word with punctuation (some providers) class Segment(BaseModel): """Represents a segment of transcribed text with detailed information.""" id: int seek: int start: float end: float text: str # OpenAI Whisper specific fields tokens: Optional[List[int]] = None temperature: Optional[float] = None avg_logprob: Optional[float] = None compression_ratio: Optional[float] = None no_speech_prob: Optional[float] = None # Common ASR provider fields confidence: Optional[float] = None # Segment-level confidence speaker: Optional[int] = None # Primary speaker for this segment speaker_confidence: Optional[float] = None # Speaker identification confidence words: Optional[List[Word]] = None # Words within this 
class Alternative(BaseModel):
    """Represents an alternative transcription hypothesis (common in many ASR APIs).

    Only ``transcript`` is required; ``confidence`` and ``words`` default to
    ``None`` when they are not provided.
    """

    # Text of this hypothesis.
    transcript: str
    # Confidence score for this hypothesis, when available.
    confidence: Optional[float] = None
    # Per-word details for this hypothesis, when available.
    words: Optional[List[Word]] = None
@dataclass
class TranscriptionOptions:
    """Provider-agnostic bag of transcription (ASR) settings.

    Every option defaults to ``None`` (meaning "not set"), except
    ``custom_parameters`` which defaults to an empty dict. Range constraints
    are checked once, right after construction.
    """

    # --- Core ---
    language: Optional[str] = None

    # --- Audio format ---
    audio_format: Optional[str] = None
    sample_rate: Optional[int] = None
    channels: Optional[int] = None
    encoding: Optional[str] = None  # audio encoding type

    # --- Output format ---
    response_format: Optional[str] = None
    include_word_timestamps: Optional[bool] = None
    include_segment_timestamps: Optional[bool] = None
    timestamp_granularities: Optional[List[str]] = None  # OpenAI: ["word", "segment"]

    # --- Context and guidance ---
    prompt: Optional[str] = None
    context_phrases: Optional[List[str]] = None
    boost_phrases: Optional[List[str]] = None

    # --- Speaker features ---
    enable_speaker_diarization: Optional[bool] = None
    max_speakers: Optional[int] = None
    min_speakers: Optional[int] = None

    # --- Text processing ---
    enable_automatic_punctuation: Optional[bool] = None
    enable_profanity_filter: Optional[bool] = None
    enable_smart_formatting: Optional[bool] = None
    enable_word_confidence: Optional[bool] = None
    enable_spoken_punctuation: Optional[bool] = None
    enable_spoken_emojis: Optional[bool] = None

    # --- Advanced features ---
    enable_sentiment_analysis: Optional[bool] = None
    enable_topic_detection: Optional[bool] = None
    enable_intent_recognition: Optional[bool] = None
    enable_summarization: Optional[bool] = None
    enable_translation: Optional[bool] = None
    translation_target_language: Optional[str] = None

    # --- Confidence and alternatives ---
    include_confidence_scores: Optional[bool] = None
    max_alternatives: Optional[int] = None

    # --- Processing ---
    temperature: Optional[float] = None
    interim_results: Optional[bool] = None
    vad_sensitivity: Optional[float] = None
    stream: Optional[bool] = None  # enable streaming output

    # --- Provider-namespaced extras ---
    custom_parameters: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Reject out-of-range values.

        Raises:
            ValueError: If ``temperature`` or ``vad_sensitivity`` lies outside
                [0.0, 1.0], if either speaker bound is below 1, or if
                ``min_speakers`` exceeds ``max_speakers``.
        """
        temp = self.temperature
        if temp is not None and not (0.0 <= temp <= 1.0):
            raise ValueError("temperature must be between 0.0 and 1.0")

        upper = self.max_speakers
        lower = self.min_speakers
        if upper is not None and upper < 1:
            raise ValueError("max_speakers must be at least 1")
        if lower is not None and lower < 1:
            raise ValueError("min_speakers must be at least 1")
        if upper is not None and lower is not None and lower > upper:
            raise ValueError("min_speakers cannot be greater than max_speakers")

        sensitivity = self.vad_sensitivity
        if sensitivity is not None and not (0.0 <= sensitivity <= 1.0):
            raise ValueError("vad_sensitivity must be between 0.0 and 1.0")

    def has_any_parameters(self) -> bool:
        """Return True when at least one option carries a value.

        ``custom_parameters`` counts only when it is non-empty; every other
        field counts when it is not ``None``.
        """
        return any(
            bool(value) if name == "custom_parameters" else value is not None
            for name, value in self.__dict__.items()
        )

    def get_set_parameters(self) -> Dict[str, Any]:
        """Return a dict holding only the options that carry a value.

        Uses the same "is set" rule as ``has_any_parameters``.
        """
        return {
            name: value
            for name, value in self.__dict__.items()
            if (value if name == "custom_parameters" else value is not None)
        }
"interim_results", "encoding": "encoding", # timestamp_granularities is handled specially for Deepgram } # Google API parameter mapping GOOGLE_MAPPING = { "language": "language_code", "sample_rate": "sample_rate_hertz", "channels": "audio_channel_count", "enable_automatic_punctuation": "enable_automatic_punctuation", "enable_speaker_diarization": "enable_speaker_diarization", "max_speakers": "diarization_speaker_count", "min_speakers": "min_speaker_count", "include_word_timestamps": "enable_word_time_offsets", "include_confidence_scores": "enable_word_confidence", "enable_word_confidence": "enable_word_confidence", "context_phrases": "speech_contexts", "enable_profanity_filter": "profanity_filter", "max_alternatives": "max_alternatives", "boost_phrases": "speech_contexts", "audio_format": "encoding", "encoding": "encoding", "interim_results": "interim_results", "stream": "interim_results", "enable_spoken_punctuation": "enable_spoken_punctuation", "enable_spoken_emojis": "enable_spoken_emojis", } @classmethod def map_to_openai(cls, options: "TranscriptionOptions") -> Dict[str, Any]: """Map TranscriptionOptions to OpenAI Whisper API parameters.""" params = {} # Handle timestamp granularities timestamp_granularities = [] if options.include_word_timestamps: timestamp_granularities.append("word") if options.include_segment_timestamps: timestamp_granularities.append("segment") if timestamp_granularities: params["timestamp_granularities"] = timestamp_granularities # Map other parameters for opt_key, api_key in cls.OPENAI_MAPPING.items(): if hasattr(options, opt_key): value = getattr(options, opt_key) if value is not None and not opt_key.startswith("include_"): params[api_key] = value # Handle custom parameters cls._apply_custom_parameters(params, options.custom_parameters, "openai") return params @classmethod def map_to_deepgram(cls, options: "TranscriptionOptions") -> Dict[str, Any]: """Map TranscriptionOptions to Deepgram API parameters.""" params = {} for opt_key, 
api_key in cls.DEEPGRAM_MAPPING.items(): if hasattr(options, opt_key): value = getattr(options, opt_key) if value is not None: params[api_key] = value # Handle special cases if options.context_phrases: params["keywords"] = options.context_phrases # Handle timestamp_granularities conversion for Deepgram if ( hasattr(options, "timestamp_granularities") and options.timestamp_granularities ): if "word" in options.timestamp_granularities: params["utterances"] = True if "segment" in options.timestamp_granularities: params["paragraphs"] = True # Handle custom parameters cls._apply_custom_parameters(params, options.custom_parameters, "deepgram") return params @classmethod def map_to_google(cls, options: "TranscriptionOptions") -> Dict[str, Any]: """Map TranscriptionOptions to Google Speech-to-Text API parameters.""" params = {} for opt_key, api_key in cls.GOOGLE_MAPPING.items(): if hasattr(options, opt_key): value = getattr(options, opt_key) if value is not None: if opt_key == "context_phrases" or opt_key == "boost_phrases": if "speech_contexts" not in params: params["speech_contexts"] = [] params["speech_contexts"].append({"phrases": value}) elif opt_key == "language": # Handle language code conversion for Google # Google expects BCP-47 locale codes like "en-US", not just "en" if len(value) == 2: # Convert "en" to "en-US" language_map = { "en": "en-US", "es": "es-ES", "fr": "fr-FR", "de": "de-DE", "it": "it-IT", "pt": "pt-BR", # Portuguese -> Brazilian Portuguese "ja": "ja-JP", "ko": "ko-KR", "zh": "zh-CN", # Chinese -> Simplified Chinese "ar": "ar-SA", # Arabic -> Saudi Arabia "hi": "hi-IN", # Hindi -> India "ru": "ru-RU", # Russian -> Russia "nl": "nl-NL", # Dutch -> Netherlands "pl": "pl-PL", # Polish -> Poland "sv": "sv-SE", # Swedish -> Sweden "da": "da-DK", # Danish -> Denmark "no": "nb-NO", # Norwegian -> Norway "fi": "fi-FI", # Finnish -> Finland "tr": "tr-TR", # Turkish -> Turkey "th": "th-TH", # Thai -> Thailand "vi": "vi-VN", # Vietnamese -> Vietnam } 
params[api_key] = language_map.get(value, f"{value}-US") else: params[api_key] = value else: params[api_key] = value # Handle audio encoding mapping if options.audio_format: encoding_map = { "wav": "LINEAR16", "flac": "FLAC", "mp3": "MP3", "ogg": "OGG_OPUS", "webm": "WEBM_OPUS", } params["encoding"] = encoding_map.get( options.audio_format.lower(), "LINEAR16" ) # Handle timestamp_granularities conversion for Google if ( hasattr(options, "timestamp_granularities") and options.timestamp_granularities ): if "word" in options.timestamp_granularities: params["enable_word_time_offsets"] = True # Handle custom parameters cls._apply_custom_parameters(params, options.custom_parameters, "google") return params @classmethod def _apply_custom_parameters( cls, params: Dict[str, Any], custom_params: Dict[str, Any], provider: str ): """ Apply custom parameters for the specific provider. Only provider-namespaced parameters are supported. Parameters not under a provider key are IGNORED. """ if not custom_params: return # Provider-specific namespacing ONLY # Users MUST structure custom_parameters like: # { # "openai": {"response_format": "srt", "temperature": 0.2}, # "deepgram": {"search": ["keyword"], "numerals": True}, # "google": {"use_enhanced": True, "adaptation": {...}} # } if provider in custom_params: params.update(custom_params[provider]) # Note: Any parameters not under a provider key are ignored ================================================ FILE: aisuite/framework/provider_interface.py ================================================ """The shared interface for model providers.""" # TODO(rohit): Remove this. This interface is obsolete in favor of Provider. class ProviderInterface: """Defines the expected behavior for provider-specific interfaces.""" def chat_completion_create(self, messages=None, model=None, temperature=0) -> None: """Create a chat completion using the specified messages, model, and temperature. 
This method must be implemented by subclasses to perform completions. Args: ---- messages (list): The chat history. model (str): The identifier of the model to be used in the completion. temperature (float): The temperature to use in the completion. Raises: ------ NotImplementedError: If this method has not been implemented by a subclass. """ raise NotImplementedError( "Provider Interface has not implemented chat_completion_create()" ) ================================================ FILE: aisuite/mcp/__init__.py ================================================ """ MCP (Model Context Protocol) integration for aisuite. This module provides support for using MCP servers and their tools with aisuite's unified interface for AI providers. MCP allows AI applications to connect to external data sources and tools through a standardized protocol. This integration makes MCP tools available as Python callables that work seamlessly with aisuite's existing tool calling infrastructure. Example: >>> from aisuite import Client >>> from aisuite.mcp import MCPClient >>> >>> # Connect to an MCP server >>> mcp = MCPClient( ... command="npx", ... args=["-y", "@modelcontextprotocol/server-filesystem", "/docs"] ... ) >>> >>> # Use MCP tools with any provider >>> client = Client() >>> response = client.chat.completions.create( ... model="openai:gpt-4o", ... messages=[{"role": "user", "content": "Read README.md"}], ... tools=mcp.get_callable_tools(), ... max_turns=2 ... ) """ from .client import MCPClient __all__ = ["MCPClient"] ================================================ FILE: aisuite/mcp/client.py ================================================ """ MCP Client for aisuite. This module provides the MCPClient class that connects to MCP servers and exposes their tools as Python callables compatible with aisuite's tool system. 
""" import asyncio import json from typing import Any, Callable, Dict, List, Optional from contextlib import contextmanager try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client import httpx except ImportError as e: if "mcp" in str(e): raise ImportError( "MCP support requires the 'mcp' package. " "Install it with: pip install 'aisuite[mcp]' or pip install mcp" ) elif "httpx" in str(e): raise ImportError( "HTTP transport requires the 'httpx' package. " "Install it with: pip install httpx" ) raise from .tool_wrapper import create_mcp_tool_wrapper from .config import MCPConfig, validate_mcp_config, get_transport_type class MCPClient: """ Client for connecting to MCP servers and using their tools with aisuite. This class manages the connection to an MCP server, discovers available tools, and creates Python callable wrappers that work seamlessly with aisuite's existing tool calling infrastructure. Example: >>> # Connect to an MCP server >>> mcp = MCPClient( ... command="npx", ... args=["-y", "@modelcontextprotocol/server-filesystem", "/path"] ... ) >>> >>> # Get tools and use with aisuite >>> import aisuite as ai >>> client = ai.Client() >>> response = client.chat.completions.create( ... model="openai:gpt-4o", ... messages=[{"role": "user", "content": "List files"}], ... tools=mcp.get_callable_tools(), ... max_turns=2 ... ) The MCPClient handles: - Starting and managing the MCP server process - Performing the MCP handshake - Discovering available tools - Creating callable wrappers for tools - Executing tool calls via the MCP protocol """ def __init__( self, command: Optional[str] = None, args: Optional[List[str]] = None, env: Optional[Dict[str, str]] = None, server_url: Optional[str] = None, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0, name: Optional[str] = None, ): """ Initialize the MCP client and connect to an MCP server. Supports both stdio and HTTP transports. 
Provide either stdio parameters (command) OR HTTP parameters (server_url), but not both. Args: command: Command to start the MCP server (e.g., "npx", "python") - for stdio transport args: Arguments to pass to the command (e.g., ["-y", "server-package"]) - for stdio transport env: Optional environment variables for the server process - for stdio transport server_url: Base URL of the MCP server (e.g., "http://localhost:8000") - for HTTP transport headers: Optional HTTP headers (e.g., for authentication) - for HTTP transport timeout: Request timeout in seconds - for HTTP transport (default: 30.0) name: Optional name for this MCP client (used for logging and prefixing) Raises: ImportError: If the mcp or httpx package is not installed ValueError: If both stdio and HTTP parameters are provided, or neither RuntimeError: If connection to the MCP server fails """ # Validate transport parameters has_stdio = command is not None has_http = server_url is not None if not (has_stdio ^ has_http): raise ValueError( "Must provide exactly one transport: either 'command' (stdio) or 'server_url' (HTTP)." 
) # Store parameters based on transport type if has_stdio: self.server_params = StdioServerParameters( command=command, args=args or [], env=env, ) self.name = name or command # Stdio-specific state self._session: Optional[ClientSession] = None self._read = None self._write = None self._stdio_context = None else: # HTTP self.server_url = server_url self.headers = headers or {} self.timeout = timeout self.name = name or server_url # HTTP-specific state (initialized in _async_connect_http) self._http_client = None self._request_id = 0 self._session_id: Optional[str] = None # MCP session ID from server # Shared state self._tools_cache: Optional[List[Dict[str, Any]]] = None self._event_loop: Optional[asyncio.AbstractEventLoop] = None # Initialize connection self._connect() @classmethod def from_config(cls, config: Dict[str, Any]) -> "MCPClient": """ Create an MCPClient from a configuration dictionary. This method validates the config and creates an MCPClient instance. It supports both stdio and HTTP transports. Args: config: MCP configuration dictionary Returns: MCPClient instance Raises: ValueError: If configuration is invalid Example (stdio): >>> config = { ... "type": "mcp", ... "name": "filesystem", ... "command": "npx", ... "args": ["-y", "@modelcontextprotocol/server-filesystem", "/docs"] ... } >>> mcp = MCPClient.from_config(config) Example (HTTP): >>> config = { ... "type": "mcp", ... "name": "api-server", ... "server_url": "http://localhost:8000", ... "headers": {"Authorization": "Bearer token"} ... 
} >>> mcp = MCPClient.from_config(config) """ # Validate and normalize config validated_config = validate_mcp_config(config) # Determine transport type transport = get_transport_type(validated_config) if transport == "stdio": return cls( command=validated_config["command"], args=validated_config.get("args", []), env=validated_config.get("env"), name=validated_config["name"], ) else: # http return cls( server_url=validated_config["server_url"], headers=validated_config.get("headers"), timeout=validated_config.get("timeout", 30.0), name=validated_config["name"], ) @staticmethod def get_tools_from_config(config: Dict[str, Any]) -> List[Callable]: """ Convenience method to create MCPClient and get callable tools from config. This is a helper that combines from_config() and get_callable_tools() in a single call. It respects the config's allowed_tools and use_tool_prefix settings. Args: config: MCP configuration dictionary Returns: List of callable tool wrappers Example: >>> config = { ... "type": "mcp", ... "name": "filesystem", ... "command": "npx", ... "args": ["..."], ... "allowed_tools": ["read_file"], ... "use_tool_prefix": True ... } >>> tools = MCPClient.get_tools_from_config(config) >>> # Returns callable tools filtered and prefixed per config """ # Validate config first validated_config = validate_mcp_config(config) # Create client client = MCPClient.from_config(validated_config) # Get tools with config settings tools = client.get_callable_tools( allowed_tools=validated_config.get("allowed_tools"), use_tool_prefix=validated_config.get("use_tool_prefix", False), ) return tools def _connect(self): """ Establish connection to the MCP server. This method: 1. Creates an event loop if needed 2. Detects transport type (stdio or HTTP) 3. Establishes connection via appropriate transport 4. Performs the MCP initialization handshake 5. 
Caches the available tools Note: Automatically handles Jupyter/IPython environments where an event loop is already running by using nest_asyncio. """ # Get or create event loop try: self._event_loop = asyncio.get_running_loop() except RuntimeError: self._event_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._event_loop) # Enable nested event loops for Jupyter/IPython compatibility # This allows run_until_complete() to work in environments where # an event loop is already running (like Jupyter notebooks) try: import nest_asyncio nest_asyncio.apply() except ImportError: # nest_asyncio not available - will work fine in regular Python # but may fail in Jupyter. User should install: pip install nest-asyncio pass # Detect transport type and run appropriate async connection if hasattr(self, "server_url"): # HTTP transport self._event_loop.run_until_complete(self._async_connect_http()) else: # Stdio transport self._event_loop.run_until_complete(self._async_connect()) async def _async_connect(self): """Async connection initialization for stdio transport.""" # Start the MCP server and store the context manager self._stdio_context = stdio_client(self.server_params) self._read, self._write = await self._stdio_context.__aenter__() # Create session self._session = ClientSession(self._read, self._write) await self._session.__aenter__() # Initialize connection await self._session.initialize() # List available tools and cache them tools_result = await self._session.list_tools() # Convert Tool objects to dicts for easier handling if hasattr(tools_result, "tools"): self._tools_cache = [ { "name": tool.name, "description": ( tool.description if hasattr(tool, "description") else "" ), "inputSchema": ( tool.inputSchema if hasattr(tool, "inputSchema") else {} ), } for tool in tools_result.tools ] else: self._tools_cache = [] async def _parse_sse_response( self, response: httpx.Response, request_id: int ) -> Dict[str, Any]: """ Parse SSE stream and extract JSON-RPC response. 
SSE format per spec: data: {"jsonrpc": "2.0", "id": 1, "result": {...}} data: {"jsonrpc": "2.0", "method": "notification", ...} The server may send multiple events (notifications, requests) before sending the final response. We collect events until we find the response matching our request_id. Args: response: HTTP response with text/event-stream content type request_id: The JSON-RPC request ID to match Returns: Response result dictionary Raises: RuntimeError: If server returns an error or no matching response found """ result = None async for line in response.aiter_lines(): line = line.strip() # Skip empty lines and comments if not line or line.startswith(":"): continue # Parse SSE data field if line.startswith("data: "): data = line[6:] # Remove 'data: ' prefix try: message = json.loads(data) # Check if this is the response to our request if message.get("id") == request_id: if "error" in message: error = message["error"] raise RuntimeError( f"MCP server error: {error.get('message', 'Unknown error')} " f"(code: {error.get('code', 'unknown')})" ) result = message.get("result", {}) # Found our response, can stop parsing break # Note: Server may send other notifications/requests # which we ignore for now (future enhancement for bidirectional comms) except json.JSONDecodeError: # Invalid JSON in SSE data, skip this event continue if result is None: raise RuntimeError( f"No response received in SSE stream for request {request_id}" ) return result async def _send_http_request( self, method: str, params: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """ Send JSON-RPC request to MCP server via HTTP. 
Args: method: JSON-RPC method name params: Optional parameters Returns: Response result Raises: RuntimeError: If HTTP request fails or server returns an error """ # Increment request ID self._request_id += 1 # Build JSON-RPC 2.0 request request_data = { "jsonrpc": "2.0", "id": self._request_id, "method": method, } if params: request_data["params"] = params # Use the exact server URL provided by the user url = self.server_url.rstrip("/") # Build headers: MCP requires Accept header with both content types # Merge with any user-provided headers and session ID request_headers = { "Accept": "application/json, text/event-stream", } if self._session_id: request_headers["Mcp-Session-Id"] = self._session_id if self.headers: request_headers.update(self.headers) try: response = await self._http_client.post( url, json=request_data, headers=request_headers ) response.raise_for_status() # Check for MCP session ID in response headers if "Mcp-Session-Id" in response.headers and not self._session_id: self._session_id = response.headers["Mcp-Session-Id"] # Check Content-Type to determine response format content_type = response.headers.get("content-type", "").lower() if "application/json" in content_type: # Handle JSON response (simple request-response) result = response.json() # Check for JSON-RPC error if "error" in result: error = result["error"] raise RuntimeError( f"MCP server error: {error.get('message', 'Unknown error')} " f"(code: {error.get('code', 'unknown')})" ) return result.get("result", {}) elif "text/event-stream" in content_type: # Handle SSE stream response return await self._parse_sse_response(response, request_data["id"]) else: raise RuntimeError( f"Unexpected Content-Type from MCP server: {content_type}" ) except httpx.HTTPError as e: raise RuntimeError( f"HTTP request to MCP server failed: {type(e).__name__}: {str(e)}" ) async def _send_notification( self, method: str, params: Optional[Dict[str, Any]] = None ): """ Send a JSON-RPC notification (no response 
expected). Notifications are JSON-RPC messages without an ID field. Per the spec, the server should not send a response. Args: method: JSON-RPC method name params: Optional parameters """ # Build JSON-RPC notification (no id field) notification = { "jsonrpc": "2.0", "method": method, } if params: notification["params"] = params # Build headers url = self.server_url.rstrip("/") request_headers = { "Accept": "application/json, text/event-stream", } if self._session_id: request_headers["Mcp-Session-Id"] = self._session_id if self.headers: request_headers.update(self.headers) try: # Send notification - don't wait for/expect a response await self._http_client.post( url, json=notification, headers=request_headers ) # Note: We don't check response for notifications except httpx.HTTPError: # Notifications may timeout or fail, which is acceptable pass async def _async_connect_http(self): """Async connection initialization for HTTP transport.""" # Create HTTP client self._http_client = httpx.AsyncClient(timeout=self.timeout) # Send initialize request init_params = { "protocolVersion": "2024-11-05", "capabilities": {"roots": {"listChanged": True}, "sampling": {}}, "clientInfo": {"name": "aisuite-mcp-client", "version": "1.0.0"}, } await self._send_http_request("initialize", init_params) # Send initialized notification (required by MCP spec) await self._send_notification("notifications/initialized") # List available tools tools_result = await self._send_http_request("tools/list") # Cache tools self._tools_cache = [ { "name": tool["name"], "description": tool.get("description", ""), "inputSchema": tool.get("inputSchema", {}), } for tool in tools_result.get("tools", []) ] def list_tools(self) -> List[Dict[str, Any]]: """ List all available tools from the MCP server. Returns: List of tool schemas in MCP format Example: >>> tools = mcp.list_tools() >>> for tool in tools: ... 
print(tool['name'], '-', tool['description']) """ if self._tools_cache is None: raise RuntimeError("Not connected to MCP server") return self._tools_cache def get_callable_tools( self, allowed_tools: Optional[List[str]] = None, use_tool_prefix: bool = False, ) -> List[Callable]: """ Get all MCP tools as Python callables compatible with aisuite. This is the primary method for using MCP tools with aisuite. It returns a list of callable wrappers that can be passed directly to the `tools` parameter of `client.chat.completions.create()`. Args: allowed_tools: Optional list of tool names to include. If None, all tools are included. use_tool_prefix: If True, prefix tool names with "{client_name}__" Returns: List of callable tool wrappers Example: >>> # Get all tools >>> mcp_tools = mcp.get_callable_tools() >>> >>> # Get specific tools only >>> mcp_tools = mcp.get_callable_tools(allowed_tools=["read_file"]) >>> >>> # Get tools with name prefixing >>> mcp_tools = mcp.get_callable_tools(use_tool_prefix=True) >>> # Tools will be named "filesystem__read_file", etc. """ all_tools = self.list_tools() # Filter tools if allowed_tools is specified if allowed_tools is not None: all_tools = [t for t in all_tools if t["name"] in allowed_tools] # Create wrappers wrappers = [] for tool in all_tools: wrapper = create_mcp_tool_wrapper(self, tool["name"], tool) # Apply prefix if requested if use_tool_prefix: original_name = wrapper.__name__ wrapper.__name__ = f"{self.name}__{original_name}" wrappers.append(wrapper) return wrappers def get_tool(self, tool_name: str) -> Optional[Callable]: """ Get a specific MCP tool by name as a Python callable. 
def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
    """
    Execute an MCP tool call synchronously.

    Called by MCPToolWrapper when the LLM requests a tool. The call is routed
    to whichever transport is currently connected (HTTP first, then stdio) and
    driven to completion on the client's event loop.

    Args:
        tool_name: Name of the tool to call
        arguments: Tool arguments as a dictionary

    Returns:
        The result from the MCP tool execution

    Raises:
        RuntimeError: If no transport is connected
    """
    # getattr() instead of direct attribute access: close() already treats
    # both handles as possibly-unset, so a lookup here must not raise
    # AttributeError when one of them was never assigned.
    if getattr(self, "_http_client", None) is not None:
        pending = self._async_call_tool_http(tool_name, arguments)
    elif getattr(self, "_session", None) is not None:
        pending = self._async_call_tool(tool_name, arguments)
    else:
        # Same transport test __repr__ uses: HTTP clients carry `server_url`.
        transport = "HTTP" if hasattr(self, "server_url") else "stdio"
        raise RuntimeError(f"Not connected to MCP server ({transport})")
    return self._event_loop.run_until_complete(pending)


async def _async_call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
    """
    Async implementation of tool calling for stdio transport.

    When the result's ``content`` is a non-empty list, only its first item
    is used: ``text`` if present, else ``data``, else ``str(item)``.

    Args:
        tool_name: Name of the tool
        arguments: Tool arguments

    Returns:
        Tool execution result
    """
    result = await self._session.call_tool(tool_name, arguments)

    # No content attribute at all: fall back to the result's string form.
    if not hasattr(result, "content"):
        return str(result)

    content = result.content
    if isinstance(content, list) and content:
        first_item = content[0]
        if hasattr(first_item, "text"):
            return first_item.text
        if hasattr(first_item, "data"):
            return first_item.data
        return str(first_item)
    # Empty list or non-list content is handed back untouched.
    return content


async def _async_call_tool_http(
    self, tool_name: str, arguments: Dict[str, Any]
) -> Any:
    """
    Async implementation of tool calling for HTTP transport.

    Sends a ``tools/call`` request and unwraps the dict result the same way
    the stdio variant unwraps its object result (first content item only).

    Args:
        tool_name: Name of the tool
        arguments: Tool arguments

    Returns:
        Tool execution result
    """
    params = {"name": tool_name, "arguments": arguments}
    result = await self._send_http_request("tools/call", params)

    # No content field: return the whole result serialized as JSON.
    if "content" not in result:
        return json.dumps(result)

    content = result["content"]
    if isinstance(content, list) and content:
        first_item = content[0]
        if isinstance(first_item, dict):
            if "text" in first_item:
                return first_item["text"]
            if "data" in first_item:
                return first_item["data"]
        return str(first_item)
    return content


def close(self):
    """
    Close the connection to the MCP server.

    Works for both stdio and HTTP transports and is safe to call more than
    once. Using MCPClient as a context manager is the recommended way to get
    cleanup; call this directly only when that is not possible.

    Example:
        >>> mcp = MCPClient(command="npx", args=["server"])
        >>> try:
        ...     # Use mcp
        ...     pass
        ... finally:
        ...     mcp.close()
    """
    has_stdio = getattr(self, "_session", None) is not None
    has_http = getattr(self, "_http_client", None) is not None
    if has_stdio or has_http:
        self._event_loop.run_until_complete(self._async_close())


async def _async_close(self):
    """
    Async cleanup for both stdio and HTTP transports.

    Cleanup errors are suppressed, except a RuntimeError that is not an
    anyio "cancel scope" error, which is re-raised. Whatever happens, the
    transport handles that exist are reset to None afterwards so a later
    close() is a no-op and call_tool() reports "not connected".
    """

    async def _exit_quietly(context):
        # Leave an async context manager, tolerating best-effort failures.
        if not context:
            return
        try:
            await context.__aexit__(None, None, None)
        except RuntimeError as e:
            # Suppress anyio cancel scope errors that occur in Jupyter/nest_asyncio
            # environments (known incompatibility between nest_asyncio and anyio
            # task groups); any other RuntimeError is a real problem.
            if "cancel scope" not in str(e).lower():
                raise
        except Exception:
            pass  # Ignore other errors during cleanup

    try:
        # stdio transport: session first, then the stdio context around it.
        await _exit_quietly(getattr(self, "_session", None))
        await _exit_quietly(getattr(self, "_stdio_context", None))

        # HTTP transport
        http_client = getattr(self, "_http_client", None)
        if http_client:
            try:
                await http_client.aclose()
            except Exception:
                pass  # Ignore errors during HTTP client cleanup
    finally:
        # Only reset attributes that already exist, so hasattr()-based
        # checks elsewhere keep seeing the same set of attributes.
        for handle in ("_session", "_stdio_context", "_http_client"):
            if hasattr(self, handle):
                setattr(self, handle, None)


def __enter__(self):
    """Context manager entry: returns the client itself."""
    return self


def __exit__(self, exc_type, exc_val, exc_tb):
    """Context manager exit: closes the client and never suppresses exceptions."""
    self.close()
    return False


def __repr__(self) -> str:
    """String representation showing the transport target and cached tool count."""
    num_tools = len(self._tools_cache) if self._tools_cache else 0
    if hasattr(self, "server_url"):
        return f"MCPClient(server_url={self.server_url!r}, tools={num_tools})"
    return f"MCPClient(command={self.server_params.command!r}, tools={num_tools})"
class MCPConfig(TypedDict, total=False):
    """Type definition for MCP tool configuration."""

    # Required fields
    type: Literal["mcp"]
    name: str

    # Transport: stdio
    command: str
    args: List[str]
    env: Dict[str, str]
    cwd: str

    # Transport: http
    server_url: str
    headers: Dict[str, str]
    timeout: float  # validated and copied through by validate_mcp_config

    # Tool filtering
    allowed_tools: List[str]

    # Namespacing
    use_tool_prefix: bool

    # Safety limits
    timeout_seconds: int
    response_bytes_cap: int

    # Connection behavior
    lazy_connect: bool


# Default values
DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_RESPONSE_BYTES_CAP = 10 * 1024 * 1024  # 10 MB
DEFAULT_USE_TOOL_PREFIX = False
DEFAULT_LAZY_CONNECT = False


def _require_type(config: Dict[str, Any], key: str, expected, label: str) -> None:
    """Raise ValueError if `key` is present in `config` but not an `expected` instance."""
    if key in config and not isinstance(config[key], expected):
        raise ValueError(f"MCP '{key}' must be {label}, got: {type(config[key])}")


def _require_positive(config: Dict[str, Any], key: str, expected, label: str) -> None:
    """Raise ValueError unless the optional `key` is a positive, non-bool number."""
    if key not in config:
        return
    value = config[key]
    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(value, bool) or not isinstance(value, expected):
        raise ValueError(f"MCP '{key}' must be {label}, got: {type(value)}")
    if value <= 0:
        raise ValueError(f"MCP '{key}' must be positive, got: {value}")


def validate_mcp_config(config: Dict[str, Any]) -> MCPConfig:
    """
    Validate and normalize an MCP tool configuration.

    This function:
    1. Validates required fields are present
    2. Auto-detects transport type (stdio vs http)
    3. Validates transport-specific required fields
    4. Sets defaults for optional fields
    5. Returns a normalized config dict

    Args:
        config: Raw MCP configuration dictionary

    Returns:
        Validated and normalized MCP configuration. Only known keys are
        copied; the input dictionary is not modified.

    Raises:
        ValueError: If configuration is invalid

    Example:
        >>> config = {
        ...     "type": "mcp",
        ...     "name": "filesystem",
        ...     "command": "npx",
        ...     "args": ["-y", "@modelcontextprotocol/server-filesystem", "/docs"]
        ... }
        >>> validated = validate_mcp_config(config)
        >>> validated['timeout_seconds']
        30
    """
    if not isinstance(config, dict):
        raise ValueError(f"MCP config must be a dict, got: {type(config)}")

    # Check type field
    if config.get("type") != "mcp":
        raise ValueError(f"Invalid config type: {config.get('type')}. Expected 'mcp'")

    # Check name field (required)
    if "name" not in config:
        raise ValueError(
            "MCP config must have 'name' field. "
            "Example: {'type': 'mcp', 'name': 'my_server', ...}"
        )
    name = config["name"]
    if not isinstance(name, str) or not name.strip():
        raise ValueError(f"MCP 'name' must be a non-empty string, got: {name}")

    # Auto-detect transport type: exactly one of the two keys must be present.
    has_stdio = "command" in config
    has_http = "server_url" in config
    if has_stdio == has_http:
        raise ValueError(
            "MCP config must have either 'command' or 'server_url'. "
            "Use one or the other to specify transport type."
        )

    # Validate stdio transport
    if has_stdio:
        _require_type(config, "command", str, "a string")
        _require_type(config, "args", list, "a list")
        _require_type(config, "env", dict, "a dict")
        _require_type(config, "cwd", str, "a string")

    # Validate http transport
    if has_http:
        _require_type(config, "server_url", str, "a string")
        server_url = config["server_url"]
        if not server_url.startswith(("http://", "https://")):
            raise ValueError(
                f"MCP 'server_url' must start with http:// or https://, got: {server_url}"
            )
        _require_type(config, "headers", dict, "a dict")
        _require_positive(config, "timeout", (int, float), "a number")

    # Validate optional fields shared by both transports
    _require_type(config, "allowed_tools", list, "a list")
    if "allowed_tools" in config and not all(
        isinstance(tool, str) for tool in config["allowed_tools"]
    ):
        raise ValueError("MCP 'allowed_tools' must be a list of strings")
    _require_type(config, "use_tool_prefix", bool, "a boolean")
    _require_positive(config, "timeout_seconds", (int, float), "a number")
    _require_positive(config, "response_bytes_cap", int, "an integer")
    _require_type(config, "lazy_connect", bool, "a boolean")

    # Create normalized config with defaults
    normalized: MCPConfig = {"type": "mcp", "name": name}

    # Copy transport fields; `args` always exists for stdio configs.
    if has_stdio:
        normalized["command"] = config["command"]
        normalized["args"] = config.get("args", [])
        passthrough = ("env", "cwd")
    else:
        normalized["server_url"] = config["server_url"]
        passthrough = ("headers", "timeout")
    for key in passthrough + ("allowed_tools",):
        if key in config:
            normalized[key] = config[key]

    # Optional fields that always appear in the result, defaulted if absent.
    normalized["use_tool_prefix"] = config.get(
        "use_tool_prefix", DEFAULT_USE_TOOL_PREFIX
    )
    normalized["timeout_seconds"] = config.get(
        "timeout_seconds", DEFAULT_TIMEOUT_SECONDS
    )
    normalized["response_bytes_cap"] = config.get(
        "response_bytes_cap", DEFAULT_RESPONSE_BYTES_CAP
    )
    normalized["lazy_connect"] = config.get("lazy_connect", DEFAULT_LAZY_CONNECT)

    return normalized


def is_mcp_config(obj: Any) -> bool:
    """
    Check if an object is an MCP config dictionary.

    Args:
        obj: Object to check

    Returns:
        True if obj is a dict with type="mcp", False otherwise

    Example:
        >>> is_mcp_config({"type": "mcp", "name": "test"})
        True
        >>> is_mcp_config(lambda: None)
        False
    """
    if not isinstance(obj, dict):
        return False
    return obj.get("type") == "mcp"


def get_transport_type(config: MCPConfig) -> Literal["stdio", "http"]:
    """
    Determine the transport type from a validated MCP config.

    Args:
        config: Validated MCP configuration

    Returns:
        "stdio" when the config has a 'command' key, otherwise "http"
    """
    return "stdio" if "command" in config else "http"
""" schema_type = schema.get("type") # Handle null/None if schema_type == "null": return type(None) # Handle basic types type_mapping = { "string": str, "number": float, "integer": int, "boolean": bool, "object": dict, "array": list, } if schema_type in type_mapping: base_type = type_mapping[schema_type] # Handle arrays with item type if schema_type == "array" and "items" in schema: item_type = json_schema_to_python_type(schema["items"]) return List[item_type] return base_type # Handle anyOf/oneOf (union types) if "anyOf" in schema or "oneOf" in schema: union_schemas = schema.get("anyOf", schema.get("oneOf", [])) types = [json_schema_to_python_type(s) for s in union_schemas] if len(types) == 1: return types[0] return Union[tuple(types)] # Default to Any if we can't determine the type return Any def mcp_schema_to_annotations(input_schema: Dict[str, Any]) -> Dict[str, type]: """ Convert MCP tool input schema to Python type annotations. MCP tools use JSON Schema for their input parameters. This function converts those schemas to Python type annotations that can be used by aisuite's Tools class. Args: input_schema: MCP tool input schema (JSON Schema format) Returns: Dictionary mapping parameter names to Python types Example: >>> schema = { ... "type": "object", ... "properties": { ... "location": {"type": "string"}, ... "count": {"type": "integer"} ... }, ... "required": ["location"] ... 
} >>> annotations = mcp_schema_to_annotations(schema) >>> annotations {'location': , 'count': typing.Optional[int]} """ annotations = {} if input_schema.get("type") != "object": return annotations properties = input_schema.get("properties", {}) required = input_schema.get("required", []) for param_name, param_schema in properties.items(): param_type = json_schema_to_python_type(param_schema) # Make optional if not in required list if param_name not in required: param_type = Optional[param_type] annotations[param_name] = param_type return annotations def create_function_signature( func_name: str, annotations: Dict[str, type], docstring: Optional[str] = None ) -> inspect.Signature: """ Create a function signature from parameter annotations. Args: func_name: Name of the function annotations: Dictionary mapping parameter names to types docstring: Optional docstring for the function Returns: inspect.Signature object """ parameters = [] for param_name, param_type in annotations.items(): # Check if it's an Optional type if get_origin(param_type) is Union: args = get_args(param_type) if type(None) in args: # It's Optional, set default to None parameters.append( inspect.Parameter( param_name, inspect.Parameter.KEYWORD_ONLY, default=None, annotation=param_type, ) ) else: parameters.append( inspect.Parameter( param_name, inspect.Parameter.KEYWORD_ONLY, annotation=param_type, ) ) else: # Required parameter parameters.append( inspect.Parameter( param_name, inspect.Parameter.KEYWORD_ONLY, annotation=param_type, ) ) return inspect.Signature(parameters) def extract_parameter_descriptions(input_schema: Dict[str, Any]) -> Dict[str, str]: """ Extract parameter descriptions from MCP schema. 
Args: input_schema: MCP tool input schema Returns: Dictionary mapping parameter names to their descriptions """ descriptions = {} properties = input_schema.get("properties", {}) for param_name, param_schema in properties.items(): if "description" in param_schema: descriptions[param_name] = param_schema["description"] return descriptions def build_docstring( tool_description: str, parameter_descriptions: Dict[str, str] ) -> str: """ Build a Python docstring from MCP tool description and parameter descriptions. Args: tool_description: Overall description of the tool parameter_descriptions: Dictionary of parameter descriptions Returns: Formatted docstring """ lines = [tool_description, ""] if parameter_descriptions: lines.append("Args:") for param_name, param_desc in parameter_descriptions.items(): lines.append(f" {param_name}: {param_desc}") return "\n".join(lines) ================================================ FILE: aisuite/mcp/tool_wrapper.py ================================================ """ MCP Tool Wrapper for aisuite. This module provides the MCPToolWrapper class, which creates Python callable wrappers around MCP tools that are compatible with aisuite's existing tool calling infrastructure. """ from typing import Any, Callable, Dict, Optional import asyncio import inspect from .schema_converter import ( mcp_schema_to_annotations, extract_parameter_descriptions, build_docstring, ) class MCPToolWrapper: """ A callable wrapper around an MCP tool that makes it compatible with aisuite. This class wraps an MCP tool and exposes it as a Python callable with proper type annotations and docstrings that aisuite's Tools class can inspect and use. The wrapper sets the following attributes that aisuite's Tools class reads: - __name__: The tool name - __doc__: The tool description and parameter documentation - __annotations__: Python type annotations for parameters When called, the wrapper executes the MCP tool via the MCP protocol. 
class MCPToolWrapper:
    """
    A callable wrapper around an MCP tool that makes it compatible with aisuite.

    This class wraps an MCP tool and exposes it as a Python callable with proper
    type annotations and docstrings that aisuite's Tools class can inspect and use.

    The wrapper sets the following attributes that aisuite's Tools class reads:
    - __name__: The tool name
    - __doc__: The tool description and parameter documentation
    - __annotations__: Python type annotations for parameters
    - __signature__: An inspect.Signature describing the parameters
    - __mcp_input_schema__: The untouched MCP inputSchema

    When called, the wrapper executes the MCP tool via the MCP protocol.

    Example:
        >>> wrapper = MCPToolWrapper(mcp_client, "read_file", tool_schema)
        >>> result = wrapper(path="/path/to/file")
    """

    def __init__(
        self,
        mcp_client: "MCPClient",  # Forward reference to avoid circular import
        tool_name: str,
        tool_schema: Dict[str, Any],
    ):
        """
        Initialize the MCP tool wrapper.

        Args:
            mcp_client: The MCPClient instance that manages the connection
            tool_name: Name of the MCP tool
            tool_schema: MCP tool schema definition
        """
        self.mcp_client = mcp_client
        self.tool_name = tool_name
        self.schema = tool_schema

        input_schema = tool_schema.get("inputSchema", {})

        # Attributes that aisuite's Tools class inspects.
        self.__name__ = tool_name
        self.__doc__ = build_docstring(
            tool_schema.get("description", ""),
            extract_parameter_descriptions(input_schema),
        )
        self.__annotations__ = mcp_schema_to_annotations(input_schema)

        # Without an explicit signature, inspect.signature() would only see
        # the **kwargs of __call__. Must run after __annotations__ is set.
        self.__signature__ = self._create_signature(input_schema)

        # Keep the original MCP inputSchema for direct use by the Tools class.
        # This avoids a lossy round trip through Python type annotations and
        # preserves all JSON Schema details (arrays, nested objects, etc.).
        self.__mcp_input_schema__ = input_schema

    def _create_signature(self, input_schema: Dict[str, Any]) -> inspect.Signature:
        """
        Create a signature for this wrapper based on MCP tool schema.

        Parameters come from ``self.__annotations__``. Those listed in the
        schema's ``required`` have no default; the rest default to ``None``.
        Required parameters are placed first (each group keeps its schema
        order) because inspect.Signature rejects a parameter without a
        default that follows one with a default.

        Args:
            input_schema: MCP tool input schema (only ``required`` is read)

        Returns:
            inspect.Signature with ``Any`` as the return annotation
        """
        required = set(input_schema.get("required") or ())
        kind = inspect.Parameter.POSITIONAL_OR_KEYWORD

        without_default = []
        with_default = []
        for param_name, annotation in self.__annotations__.items():
            if param_name in required:
                without_default.append(
                    inspect.Parameter(param_name, kind, annotation=annotation)
                )
            else:
                with_default.append(
                    inspect.Parameter(
                        param_name, kind, default=None, annotation=annotation
                    )
                )

        return inspect.Signature(
            without_default + with_default, return_annotation=Any
        )

    def __call__(self, **kwargs) -> Any:
        """
        Execute the MCP tool with the given arguments.

        This method is called by aisuite's tool execution loop when the LLM
        requests this tool.

        Args:
            **kwargs: Tool arguments as keyword arguments

        Returns:
            The result from the MCP tool execution
        """
        # Drop None values so optional parameters are omitted rather than sent
        # as null (a tool expecting a number won't accept null).
        supplied = {key: value for key, value in kwargs.items() if value is not None}

        # The MCP client handles the async MCP protocol communication.
        return self.mcp_client.call_tool(self.tool_name, supplied)

    def __repr__(self) -> str:
        """Return a string representation of the wrapper."""
        return f"MCPToolWrapper(name={self.tool_name!r})"


def create_mcp_tool_wrapper(
    mcp_client: "MCPClient",
    tool_name: str,
    tool_schema: Dict[str, Any],
) -> Callable:
    """
    Factory function to create an MCP tool wrapper.

    Args:
        mcp_client: The MCPClient instance
        tool_name: Name of the tool
        tool_schema: MCP tool schema

    Returns:
        Callable wrapper for the MCP tool
    """
    return MCPToolWrapper(mcp_client, tool_name, tool_schema)
class LLMError(Exception):
    """Custom exception for LLM errors."""

    def __init__(self, message):
        super().__init__(message)


class ASRError(Exception):
    """Custom exception for ASR errors."""

    def __init__(self, message):
        super().__init__(message)


class Provider(ABC):
    """Base class every provider implements; audio support is opt-in."""

    def __init__(self):
        """Initialize provider with optional audio functionality."""
        # Stays None unless a subclass assigns an Audio implementation.
        self.audio: Optional[Audio] = None

    @abstractmethod
    def chat_completions_create(self, model, messages):
        """Abstract method for chat completion calls, to be implemented by each provider."""
        pass


class ProviderFactory:
    """Factory to dynamically load provider instances based on naming conventions."""

    PROVIDERS_DIR = Path(__file__).parent / "providers"

    @classmethod
    def create_provider(cls, provider_key, config):
        """
        Dynamically load and create an instance of a provider based on the naming convention.

        For a key such as ``"groq"`` this imports
        ``aisuite.providers.groq_provider`` and instantiates ``GroqProvider``.

        Args:
            provider_key: Provider name used to derive the module and class names
            config: Keyword arguments passed to the provider class constructor

        Returns:
            An instance of the provider class

        Raises:
            ImportError: If the provider module cannot be imported; the
                original import failure is attached as ``__cause__``
        """
        # Convert provider_key to the expected module and class names
        module_path = f"aisuite.providers.{provider_key}_provider"
        provider_class_name = f"{provider_key.capitalize()}Provider"

        # Lazily load the module
        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            # Chain explicitly so the traceback presents the underlying import
            # failure as the direct cause rather than an incidental context.
            raise ImportError(
                f"Could not import module {module_path}: {str(e)}. "
                "Please ensure the provider is supported by doing "
                "ProviderFactory.get_supported_providers()"
            ) from e

        # Instantiate the provider class
        provider_class = getattr(module, provider_class_name)
        return provider_class(**config)

    @classmethod
    @functools.cache
    def get_supported_providers(cls):
        """
        List all supported provider names based on files present in the providers directory.

        Returns:
            Set of names derived from ``*_provider.py`` files under
            PROVIDERS_DIR; the result is cached after the first call
        """
        return {
            provider_file.stem.replace("_provider", "")
            for provider_file in Path(cls.PROVIDERS_DIR).glob("*_provider.py")
        }


class Audio:
    """Base class for all audio functionality."""

    def __init__(self):
        # Stays None unless a provider assigns a Transcription implementation.
        self.transcriptions: Optional["Audio.Transcription"] = None

    class Transcription(ABC):
        """Base class for audio transcription functionality."""

        def create(
            self,
            model: str,
            file: Union[str, BinaryIO],
            options=None,
            **kwargs,
        ):
            """Create audio transcription; the base implementation always raises NotImplementedError."""
            raise NotImplementedError("Transcription not supported by this provider")

        async def create_stream_output(
            self,
            model: str,
            file: Union[str, BinaryIO],
            options=None,
            **kwargs,
        ):
            """Create streaming audio transcription; the base implementation always raises NotImplementedError."""
            raise NotImplementedError(
                "Streaming transcription not supported by this provider"
            )
class AnthropicMessageConverter:
    """Translate between the framework's OpenAI-style messages and Anthropic's format."""

    # Role constants
    ROLE_USER = "user"
    ROLE_ASSISTANT = "assistant"
    ROLE_TOOL = "tool"
    ROLE_SYSTEM = "system"

    # Finish reason mapping
    FINISH_REASON_MAPPING = {
        "end_turn": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
    }

    def convert_request(self, messages):
        """
        Convert framework messages to Anthropic format.

        Args:
            messages: List of message dicts or Message objects. The list is
                not modified.

        Returns:
            Tuple of (system message content or [] when absent, list of
            converted non-system-leading messages)
        """
        system_message = self._extract_system_message(messages)
        # Slice rather than pop: the caller usually keeps appending to this
        # same list across turns and must not lose its system prompt.
        remaining = messages[1:] if self._starts_with_system(messages) else messages
        converted_messages = [self._convert_single_message(msg) for msg in remaining]
        return system_message, converted_messages

    def convert_response(self, response):
        """Normalize the response from the Anthropic API to match OpenAI's response format."""
        normalized_response = ChatCompletionResponse()
        choice = normalized_response.choices[0]
        choice.finish_reason = self._get_finish_reason(response)
        normalized_response.usage = self._get_completion_usage(response)
        choice.message = self._get_message(response)
        return normalized_response

    @staticmethod
    def _field(msg, name, default=None):
        """Read `name` from a message given either as a dict or as an object."""
        if isinstance(msg, dict):
            return msg.get(name, default)
        return getattr(msg, name, default)

    def _starts_with_system(self, messages):
        """Return True when the first message has the system role."""
        return bool(messages) and self._field(messages[0], "role") == self.ROLE_SYSTEM

    def _convert_single_message(self, msg):
        """Convert a single message (dict or Message object) to Anthropic format."""
        if isinstance(msg, dict):
            return self._convert_dict_message(msg)
        return self._convert_message_object(msg)

    def _convert_dict_message(self, msg):
        """Convert a dictionary message to Anthropic format."""
        role = msg["role"]
        if role == self.ROLE_TOOL:
            return self._create_tool_result_message(msg["tool_call_id"], msg["content"])
        # Truthiness, not key presence: dumped messages carry "tool_calls": None.
        if role == self.ROLE_ASSISTANT and msg.get("tool_calls"):
            return self._create_assistant_tool_message(
                msg["content"], msg["tool_calls"]
            )
        return {"role": role, "content": msg["content"]}

    def _convert_message_object(self, msg):
        """Convert a Message object to Anthropic format."""
        if msg.role == self.ROLE_TOOL:
            return self._create_tool_result_message(msg.tool_call_id, msg.content)
        if msg.role == self.ROLE_ASSISTANT and msg.tool_calls:
            return self._create_assistant_tool_message(msg.content, msg.tool_calls)
        return {"role": msg.role, "content": msg.content}

    def _create_tool_result_message(self, tool_call_id, content):
        """Create a tool result message in Anthropic format (sent with the user role)."""
        tool_result = {
            "type": "tool_result",
            "tool_use_id": tool_call_id,
            "content": content,
        }
        return {"role": self.ROLE_USER, "content": [tool_result]}

    def _create_assistant_tool_message(self, content, tool_calls):
        """
        Create an assistant message with tool calls in Anthropic format.

        Args:
            content: Optional text accompanying the tool calls
            tool_calls: Tool calls as dicts or objects; ``arguments`` may be a
                JSON string or an already-decoded dict

        Returns:
            Assistant message whose content is a list of text/tool_use blocks
        """
        message_content = []
        if content:
            message_content.append({"type": "text", "text": content})

        for tool_call in tool_calls:
            if isinstance(tool_call, dict):
                call_id = tool_call["id"]
                name = tool_call["function"]["name"]
                arguments = tool_call["function"]["arguments"]
            else:
                call_id = tool_call.id
                name = tool_call.function.name
                arguments = tool_call.function.arguments

            if isinstance(arguments, str):
                # An empty argument string means "no arguments".
                tool_input = json.loads(arguments) if arguments.strip() else {}
            else:
                tool_input = arguments if arguments is not None else {}

            message_content.append(
                {"type": "tool_use", "id": call_id, "name": name, "input": tool_input}
            )

        return {"role": self.ROLE_ASSISTANT, "content": message_content}

    def _extract_system_message(self, messages):
        """
        Return the leading system message's content, or [] when there is none.

        Does not modify `messages`; convert_request skips the leading system
        message itself.
        """
        # TODO: This is a temporary solution to extract the system message.
        # User can pass multiple system messages, which can be mingled with
        # other messages. This needs to be fixed to handle that case.
        if self._starts_with_system(messages):
            return self._field(messages[0], "content")
        return []

    def _get_finish_reason(self, response):
        """Get the normalized finish reason; unknown stop reasons map to "stop"."""
        return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop")

    def _get_completion_usage(self, response):
        """Get the usage statistics."""
        usage = response.usage
        return CompletionUsage(
            completion_tokens=usage.output_tokens,
            prompt_tokens=usage.input_tokens,
            total_tokens=usage.input_tokens + usage.output_tokens,
            prompt_tokens_details=PromptTokensDetails(
                # Tolerate usage objects that do not expose the cache counter.
                cached_tokens=getattr(usage, "cache_read_input_tokens", None),
            ),
        )

    def _get_message(self, response):
        """Get the appropriate message based on response type."""
        # Tool use is detected from the content blocks, regardless of stop_reason.
        if any(block.type == "tool_use" for block in response.content):
            tool_message = self.convert_response_with_tool_use(response)
            if tool_message:
                return tool_message

        return Message(
            content=self._first_text(response) or None,
            role="assistant",
            tool_calls=None,
            refusal=None,
        )

    @staticmethod
    def _first_text(response):
        """Return the text of the first text block at any position, or ""."""
        return next(
            (block.text for block in response.content if block.type == "text"),
            "",
        )

    def convert_response_with_tool_use(self, response):
        """
        Convert Anthropic tool use response to the framework's format.

        Every ``tool_use`` block becomes one tool call, so parallel tool use
        is preserved.

        Returns:
            Assistant Message carrying the tool calls (and the first text
            block, if any), or None when there is no tool_use block
        """
        tool_calls = [
            ChatCompletionMessageToolCall(
                id=block.id,
                function=Function(name=block.name, arguments=json.dumps(block.input)),
                type="function",
            )
            for block in response.content
            if block.type == "tool_use"
        ]
        if not tool_calls:
            return None

        return Message(
            content=self._first_text(response) or None,
            tool_calls=tool_calls,
            role="assistant",
            refusal=None,
        )

    def convert_tool_spec(self, openai_tools):
        """
        Convert OpenAI tool specification to Anthropic format.

        Entries whose type is not "function" are skipped. A missing
        description or parameters section defaults to empty.
        """
        anthropic_tools = []
        for tool in openai_tools:
            if tool.get("type") != "function":
                continue

            function = tool["function"]
            parameters = function.get("parameters") or {}
            anthropic_tools.append(
                {
                    "name": function["name"],
                    "description": function.get("description", ""),
                    "input_schema": {
                        "type": "object",
                        "properties": parameters.get("properties", {}),
                        "required": parameters.get("required", []),
                    },
                }
            )
        return anthropic_tools
# Define a constant for the default max_tokens value
DEFAULT_MAX_TOKENS = 4096


class AnthropicProvider(Provider):
    """Provider that serves chat completions through the Anthropic Messages API."""

    def __init__(self, **config):
        """
        Initialize the Anthropic provider with the given configuration.

        Args:
            **config: Keyword arguments forwarded to ``anthropic.Anthropic``
        """
        # Run the base initializer so attributes it defines (e.g. `audio`)
        # exist on this provider too.
        super().__init__()
        self.client = anthropic.Anthropic(**config)
        self.converter = AnthropicMessageConverter()

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Create a chat completion using the Anthropic API.

        Args:
            model: Anthropic model name
            messages: Conversation as message dicts or Message objects
            **kwargs: Extra request options; ``max_tokens`` defaults to
                DEFAULT_MAX_TOKENS and ``tools`` is converted from the OpenAI
                tool format

        Returns:
            The Anthropic response normalized by the message converter
        """
        request_kwargs = self._prepare_kwargs(kwargs)
        # Convert a shallow copy so the caller's message list is never altered
        # by system-message extraction.
        system_message, converted_messages = self.converter.convert_request(
            list(messages)
        )

        response = self.client.messages.create(
            model=model,
            system=system_message,
            messages=converted_messages,
            **request_kwargs,
        )
        return self.converter.convert_response(response)

    def _prepare_kwargs(self, kwargs):
        """Return a copy of kwargs with max_tokens defaulted and tools converted."""
        prepared = kwargs.copy()
        prepared.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

        if "tools" in prepared:
            prepared["tools"] = self.converter.convert_tool_spec(prepared["tools"])

        return prepared
# pylint: disable=too-few-public-methods
class BedrockConfig:
    """Configuration for the AWS Bedrock provider."""

    INFERENCE_PARAMETERS = ["maxTokens", "temperature", "topP", "stopSequences"]

    def __init__(self, **config):
        """
        Initialize the BedrockConfig.

        Args:
            **config: Optional ``region_name``; falls back to the AWS_REGION
                environment variable, then to "us-west-2"
        """
        default_region = os.getenv("AWS_REGION", "us-west-2")
        self.region_name = config.get("region_name", default_region)

    def create_client(self):
        """Create a Bedrock runtime client for the configured region."""
        return boto3.client("bedrock-runtime", region_name=self.region_name)


# AWS Bedrock API Example -
# https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use-inference-call.html
# https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use-examples.html
class BedrockMessageConverter:
    """Converts messages between OpenAI and AWS Bedrock formats."""

    # Bedrock stopReason -> OpenAI finish_reason. Values not listed here are
    # passed through unchanged.
    FINISH_REASON_MAPPING = {
        "complete": "stop",
        "end_turn": "stop",
        "stop_sequence": "stop",
        "max_tokens": "length",
    }

    @staticmethod
    def convert_request(
        messages: List[Dict[str, Any]],
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Convert messages to AWS Bedrock format.

        Args:
            messages: Message dicts, or objects exposing ``model_dump()``

        Returns:
            Tuple of (system blocks, formatted messages). Only a leading
            system message is used; later system messages are skipped.
        """
        # Convert all messages to dicts if they're Message objects
        plain = [
            item.model_dump() if hasattr(item, "model_dump") else item
            for item in messages
        ]

        # Handle system message
        system_message = []
        if plain and plain[0]["role"] == "system":
            system_message = [{"text": plain[0]["content"]}]
            plain = plain[1:]

        formatted_messages = []
        for message in plain:
            role = message["role"]
            if role == "system":
                # Skip any additional system messages
                continue
            if role == "tool":
                converted = BedrockMessageConverter.convert_tool_result(message)
            elif role == "assistant":
                converted = BedrockMessageConverter.convert_assistant(message)
            else:  # user messages
                converted = {"role": role, "content": [{"text": message["content"]}]}
            if converted:
                formatted_messages.append(converted)

        return system_message, formatted_messages

    @staticmethod
    def convert_response_tool_call(
        response: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """
        Convert AWS Bedrock tool call response to OpenAI format.

        Returns:
            Assistant message dict with one tool call per ``toolUse`` block
            and the first text block (if any) as content, or None when the
            stopReason is not "tool_use" or no toolUse block is present
        """
        if response.get("stopReason") != "tool_use":
            return None

        blocks = response["output"]["message"]["content"]
        tool_calls = [
            {
                "type": "function",
                "id": block["toolUse"]["toolUseId"],
                "function": {
                    "name": block["toolUse"]["name"],
                    "arguments": json.dumps(block["toolUse"]["input"]),
                },
            }
            for block in blocks
            if "toolUse" in block
        ]
        if not tool_calls:
            return None

        # Keep text the model produced alongside the tool request.
        text = next((block["text"] for block in blocks if "text" in block), None)
        return {
            "role": "assistant",
            "content": text or None,
            "tool_calls": tool_calls,
            "refusal": None,
        }

    @staticmethod
    def convert_tool_result(message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Convert OpenAI tool result format to AWS Bedrock format.

        Content that is (or decodes to) a JSON object is sent as a ``json``
        block; everything else is sent as a ``text`` block.

        Raises:
            LLMError: If the message has no tool_call_id
        """
        if message["role"] != "tool" or "content" not in message:
            return None

        tool_call_id = message.get("tool_call_id")
        if not tool_call_id:
            raise LLMError("Tool result message must include tool_call_id")

        raw = message["content"]
        if isinstance(raw, dict):
            content = [{"json": raw}]
        elif isinstance(raw, str):
            try:
                decoded = json.loads(raw)
            except json.JSONDecodeError:
                decoded = None
            # Only JSON objects go into a json block; scalars and arrays
            # (e.g. "42", "[1, 2]") stay as their original text.
            content = [{"json": decoded}] if isinstance(decoded, dict) else [{"text": raw}]
        else:
            # Non-string, non-dict payloads are serialized instead of crashing.
            content = [{"text": json.dumps(raw, default=str)}]

        return {
            "role": "user",
            "content": [
                {"toolResult": {"toolUseId": tool_call_id, "content": content}}
            ],
        }

    @staticmethod
    def convert_assistant(message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Convert OpenAI assistant format to AWS Bedrock format.

        Returns:
            Assistant message with text and/or toolUse blocks, or None when
            the message is not an assistant message or has neither
        """
        if message["role"] != "assistant":
            return None

        content = []
        if message.get("content"):
            content.append({"text": message["content"]})

        for tool_call in message.get("tool_calls") or []:
            if tool_call["type"] != "function":
                continue
            arguments = tool_call["function"]["arguments"]
            if isinstance(arguments, str):
                try:
                    input_json = json.loads(arguments)
                except json.JSONDecodeError:
                    # Undecodable arguments are forwarded as the raw string.
                    input_json = arguments
            else:
                # Already-decoded arguments are used as they are.
                input_json = arguments
            content.append(
                {
                    "toolUse": {
                        "toolUseId": tool_call["id"],
                        "name": tool_call["function"]["name"],
                        "input": input_json,
                    }
                }
            )

        return {"role": "assistant", "content": content} if content else None

    @staticmethod
    def convert_response(response: Dict[str, Any]) -> "ChatCompletionResponse":
        """
        Normalize the response from the Bedrock API to match OpenAI's response format.

        Usage data, when present, is attached for tool-call responses as well
        as text responses.
        """
        norm_response = ChatCompletionResponse()
        choice = norm_response.choices[0]

        # Conditionally parse usage data if it exists.
        if usage_data := response.get("usage"):
            norm_response.usage = BedrockMessageConverter.get_completion_usage(
                usage_data
            )

        # Check if the model is requesting tool use
        tool_message = BedrockMessageConverter.convert_response_tool_call(response)
        if tool_message:
            choice.message = Message(**tool_message)
            choice.finish_reason = "tool_calls"
            return norm_response

        # Handle regular text response: first block that actually carries text.
        blocks = response["output"]["message"]["content"]
        choice.message.content = next(
            (block["text"] for block in blocks if "text" in block), None
        )

        # Map Bedrock stopReason to OpenAI finish_reason
        stop_reason = response.get("stopReason")
        choice.finish_reason = BedrockMessageConverter.FINISH_REASON_MAPPING.get(
            stop_reason, stop_reason
        )
        return norm_response

    @staticmethod
    def get_completion_usage(usage_data: dict):
        """Get the usage statistics from a usage data dictionary."""
        return CompletionUsage(
            completion_tokens=usage_data.get("outputTokens"),
            prompt_tokens=usage_data.get("inputTokens"),
            total_tokens=usage_data.get("totalTokens"),
        )
tool in kwargs["tools"] ] } return tool_config def _prepare_request_config(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Prepare the configuration for the Bedrock API request.""" # Convert tools and remove from kwargs tool_config = self._convert_tool_spec(kwargs) kwargs.pop("tools", None) # Remove tools from kwargs if present inference_config = { key: kwargs[key] for key in BedrockConfig.INFERENCE_PARAMETERS if key in kwargs } additional_fields = { key: value for key, value in kwargs.items() if key not in BedrockConfig.INFERENCE_PARAMETERS } request_config = { "inferenceConfig": inference_config, "additionalModelRequestFields": additional_fields, } if tool_config is not None: request_config["toolConfig"] = tool_config return request_config def chat_completions_create( self, model: str, messages: List[Dict[str, Any]], **kwargs ) -> ChatCompletionResponse: """Create a chat completion request to AWS Bedrock.""" system_message, formatted_messages = self.transformer.convert_request(messages) request_config = self._prepare_request_config(kwargs) try: response = self.client.converse( modelId=model, messages=formatted_messages, system=system_message, **request_config, ) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "ValidationException": error_message = e.response["Error"]["Message"] raise LLMError(error_message) from e raise return self.convert_response(response) ================================================ FILE: aisuite/providers/azure_provider.py ================================================ import urllib.request import json import os from aisuite.provider import Provider from aisuite.framework import ChatCompletionResponse from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function # Azure provider is based on the documentation here - # https://learn.microsoft.com/en-us/azure/machine-learning/reference-model-inference-api?view=azureml-api-2&source=recommendations&tabs=python # Azure AI Model 
Inference API is used. # From the documentation - # """ # The Azure AI Model Inference is an API that exposes a common set of capabilities for foundational models # and that can be used by developers to consume predictions from a diverse set of models in a uniform and consistent way. # Developers can talk with different models deployed in Azure AI Foundry portal without changing the underlying code they are using. # # The Azure AI Model Inference API is available in the following models: # # Models deployed to serverless API endpoints: # Cohere Embed V3 family of models # Cohere Command R family of models # Meta Llama 2 chat family of models # Meta Llama 3 instruct family of models # Mistral-Small # Mistral-Large # Jais family of models # Jamba family of models # Phi-3 family of models # # Models deployed to managed inference: # Meta Llama 3 instruct family of models # Phi-3 family of models # Mixtral family of models # # The API is compatible with Azure OpenAI model deployments. # """ class AzureMessageConverter: @staticmethod def convert_request(messages): """Convert messages to Azure format.""" transformed_messages = [] for message in messages: if isinstance(message, Message): transformed_messages.append(message.model_dump(mode="json")) else: transformed_messages.append(message) return transformed_messages @staticmethod def convert_response(resp_json) -> ChatCompletionResponse: """Normalize the response from the Azure API to match OpenAI's response format.""" completion_response = ChatCompletionResponse() choice = resp_json["choices"][0] message = choice["message"] # Set basic message content completion_response.choices[0].message.content = message.get("content") completion_response.choices[0].message.role = message.get("role", "assistant") # Handle tool calls if present if "tool_calls" in message and message["tool_calls"] is not None: tool_calls = [] for tool_call in message["tool_calls"]: new_tool_call = ChatCompletionMessageToolCall( id=tool_call["id"], 
def chat_completions_create(self, model, messages, **kwargs):
    """Send a chat completion request to the configured Azure endpoint.

    Args:
        model: Accepted for interface compatibility; it is not placed in the
            request payload by this method.
        messages: Conversation messages, converted by ``self.transformer``.
        **kwargs: Extra payload fields. ``stream`` is dropped; ``tools`` and
            ``tool_choice`` are forwarded along with everything else.

    Returns:
        The result of ``self.transformer.convert_response`` on the decoded
        JSON reply.

    Raises:
        Exception: on an HTTP error, with status code, headers and response
            body in the message and the original ``HTTPError`` as its cause.
    """
    # Tolerate a trailing slash on base_url so the path never contains "//".
    url = f"{self.base_url.rstrip('/')}/chat/completions"
    if self.api_version:
        url = f"{url}?api-version={self.api_version}"

    # Remove 'stream' from kwargs if present
    kwargs.pop("stream", None)

    # Build the payload: messages first, then tool settings, then the rest.
    data = {"messages": self.transformer.convert_request(messages)}
    for key in ("tools", "tool_choice"):
        if key in kwargs:
            data[key] = kwargs.pop(key)
    data.update(kwargs)

    body = json.dumps(data).encode("utf-8")
    headers = {"Content-Type": "application/json", "Authorization": self.api_key}

    try:
        req = urllib.request.Request(url, body, headers)
        with urllib.request.urlopen(req) as response:
            result = response.read()
            resp_json = json.loads(result)
            return self.transformer.convert_response(resp_json)
    except urllib.error.HTTPError as error:
        error_message = f"The request failed with status code: {error.code}\n"
        error_message += f"Headers: {error.info()}\n"
        error_message += error.read().decode("utf-8", "ignore")
        # Chain the original error so the traceback keeps the HTTP failure.
        raise Exception(error_message) from error
class CohereMessageConverter:
    """
    Cohere-specific message converter
    """

    # Cohere finish reasons -> OpenAI finish reasons. Anything not listed is
    # reported as "stop", which is what this converter always returned before.
    _FINISH_REASONS = {
        "COMPLETE": "stop",
        "STOP_SEQUENCE": "stop",
        "MAX_TOKENS": "length",
        "TOOL_CALL": "tool_calls",
    }

    @staticmethod
    def _map_finish_reason(finish_reason):
        """Translate a Cohere finish reason into the OpenAI vocabulary."""
        return CohereMessageConverter._FINISH_REASONS.get(finish_reason, "stop")

    def convert_request(self, messages):
        """Convert framework messages to Cohere format.

        Each message may be a dict or an object with ``role``, ``content`` and
        ``tool_calls`` attributes.
        """
        converted_messages = []
        for message in messages:
            is_dict = isinstance(message, dict)
            if is_dict:
                role = message.get("role")
                content = message.get("content")
                tool_calls = message.get("tool_calls")
                tool_plan = message.get("tool_plan")
            else:
                role = message.role
                content = message.content
                tool_calls = message.tool_calls
                tool_plan = getattr(message, "tool_plan", None)

            if role == "tool":
                # Tool response message
                converted_message = {
                    "role": role,
                    "tool_call_id": (
                        message.get("tool_call_id") if is_dict else message.tool_call_id
                    ),
                    "content": self._convert_tool_content(content),
                }
            elif role == "assistant" and tool_calls:
                # Assistant message carrying tool calls
                converted_message = {
                    "role": role,
                    "tool_calls": [
                        {
                            "id": tc["id"] if isinstance(tc, dict) else tc.id,
                            "function": {
                                "name": (
                                    tc["function"]["name"]
                                    if isinstance(tc, dict)
                                    else tc.function.name
                                ),
                                "arguments": (
                                    tc["function"]["arguments"]
                                    if isinstance(tc, dict)
                                    else tc.function.arguments
                                ),
                            },
                            "type": "function",
                        }
                        for tc in tool_calls
                    ],
                    "tool_plan": tool_plan,
                }
                if content:
                    converted_message["content"] = content
            else:
                # Regular message
                converted_message = {"role": role, "content": content}

            converted_messages.append(converted_message)

        return converted_messages

    def _convert_tool_content(self, content):
        """Convert tool response content to Cohere's expected format.

        JSON strings become a single ``document`` entry, other strings and
        lists are returned unchanged, and anything else is stringified.
        """
        if isinstance(content, str):
            try:
                data = json.loads(content)
            except json.JSONDecodeError:
                # Not JSON: pass the plain text through.
                return content
            return [{"type": "document", "document": {"data": json.dumps(data)}}]
        if isinstance(content, list):
            # Treated as already being in Cohere's format.
            return content
        return str(content)

    @staticmethod
    def convert_response(response_data) -> "ChatCompletionResponse":
        """Convert Cohere's response to our standard format."""
        normalized_response = ChatCompletionResponse()

        # Usage is optional on the response; only report it when present.
        tokens = getattr(getattr(response_data, "usage", None), "tokens", None)
        if tokens is not None:
            normalized_response.usage = {
                "prompt_tokens": tokens.input_tokens,
                "completion_tokens": tokens.output_tokens,
                "total_tokens": tokens.input_tokens + tokens.output_tokens,
            }

        if response_data.finish_reason == "TOOL_CALL":
            # Convert every requested tool call, not just the first one.
            tool_call_objs = [
                ChatCompletionMessageToolCall(
                    id=tool_call.id,
                    function=Function(
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    ),
                    type="function",
                )
                for tool_call in (response_data.message.tool_calls or [])
            ]
            normalized_response.choices[0].message = Message(
                content=response_data.message.tool_plan,  # Use tool_plan as content
                tool_calls=tool_call_objs,
                role="assistant",
                refusal=None,
            )
            normalized_response.choices[0].finish_reason = "tool_calls"
        else:
            # Join all text items; content may be missing or empty.
            items = response_data.message.content or []
            text = "".join(
                item.text for item in items if getattr(item, "text", None)
            )
            normalized_response.choices[0].message.content = text if items else None
            normalized_response.choices[0].finish_reason = (
                CohereMessageConverter._map_finish_reason(response_data.finish_reason)
            )

        return normalized_response
""" # Ensure API key is provided either in config or via environment variable config.setdefault("api_key", os.getenv("CO_API_KEY")) if not config["api_key"]: raise ValueError( "Cohere API key is missing. Please provide it in the config or set the CO_API_KEY environment variable." ) self.client = cohere.ClientV2(**config) self.transformer = CohereMessageConverter() def chat_completions_create(self, model, messages, **kwargs): """ Makes a request to Cohere using the official client. """ try: # Transform messages using converter transformed_messages = self.transformer.convert_request(messages) # Make the request to Cohere response = self.client.chat( model=model, messages=transformed_messages, **kwargs ) return self.transformer.convert_response(response) except Exception as e: raise LLMError(f"An error occurred: {e}") ================================================ FILE: aisuite/providers/deepgram_provider.py ================================================ import os import json import numpy as np import queue import threading import time from typing import Union, BinaryIO, AsyncGenerator from aisuite.provider import Provider, ASRError, Audio from aisuite.framework.message import ( TranscriptionResult, Segment, Word, Alternative, Channel, StreamingTranscriptionChunk, ) class DeepgramProvider(Provider): """Deepgram ASR provider.""" def __init__(self, **config): """Initialize the Deepgram provider with the given configuration.""" super().__init__() # Ensure API key is provided either in config or via environment variable self.api_key = config.get("api_key") or os.getenv("DEEPGRAM_API_KEY") if not self.api_key: raise ValueError( "Deepgram API key is missing. Please provide it in the config or set the DEEPGRAM_API_KEY environment variable." ) # Initialize Deepgram client (v5.0.0+) try: from deepgram import DeepgramClient self.client = DeepgramClient(api_key=self.api_key) except ImportError: raise ImportError( "Deepgram SDK is required. 
def chat_completions_create(self, model, messages, **kwargs):
    """Deepgram does not support chat completions.

    ``**kwargs`` is accepted so that a call carrying extra chat parameters
    fails with the explanatory ``NotImplementedError`` below rather than a
    ``TypeError`` about unexpected keyword arguments.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "Deepgram provider only supports audio transcription, not chat completions."
    )
streaming audio transcription using Deepgram SDK v5 with chunked processing. All parameters are already validated and mapped by the Client layer. This implementation handles audio chunking and streaming. """ try: # Load and prepare audio audio_data, sample_rate = await self._load_and_prepare_audio(file) # Calculate chunking strategy duration_seconds = len(audio_data) / sample_rate chunk_duration_seconds = chunk_size_minutes * 60 if duration_seconds <= chunk_duration_seconds: chunks = [audio_data] else: chunk_size_samples = int(chunk_duration_seconds * sample_rate) chunks = [] num_chunks = int(np.ceil(duration_seconds / chunk_duration_seconds)) for i in range(num_chunks): start_sample = i * chunk_size_samples end_sample = min( start_sample + chunk_size_samples, len(audio_data) ) chunks.append(audio_data[start_sample:end_sample]) # Setup API parameters for v5 kwargs["model"] = model kwargs.setdefault("smart_format", "true") kwargs.setdefault("punctuate", "true") kwargs.setdefault("language", "en") kwargs["interim_results"] = ( "true" # Enable interim results for streaming ) # Remove parameters not supported by streaming kwargs.pop("utterances", None) # Add critical audio format parameters (as strings for v5) kwargs["encoding"] = "linear16" # PCM16 format kwargs["sample_rate"] = "16000" # Match our target sample rate kwargs["channels"] = "1" # Mono audio # Use thread-safe queue for cross-thread communication transcript_queue = queue.Queue() connection_closed = threading.Event() def on_message(*args, **message_kwargs): """Handle transcript events""" # Extract result from args or kwargs result = None if len(args) >= 2: result = args[1] elif "result" in message_kwargs: result = message_kwargs["result"] else: return if hasattr(result, "channel") and result.channel.alternatives: alt = result.channel.alternatives[0] if alt.transcript: chunk = StreamingTranscriptionChunk( text=alt.transcript, is_final=getattr(result, "is_final", False), confidence=getattr(alt, "confidence", 
None), ) transcript_queue.put(chunk) def on_error(*args, **error_kwargs): """Handle error events""" error = None if len(args) >= 2: error = args[1] elif "error" in error_kwargs: error = error_kwargs["error"] if error: transcript_queue.put( ASRError(f"Deepgram streaming error: {error}") ) def on_close(*args, **close_kwargs): """Handle connection close events""" connection_closed.set() # Use v5 streaming API with context manager from deepgram.core.events import EventType async with self.client.listen.v1.connect(**kwargs) as connection: # Register event handlers connection.on(EventType.Transcript, on_message) connection.on(EventType.Error, on_error) connection.on(EventType.Close, on_close) # Send all chunks through connection for audio_chunk in chunks: self._send_audio_chunk(connection, audio_chunk) # Send CloseStream message to signal end close_stream_message = json.dumps({"type": "CloseStream"}) connection.send(close_stream_message) # Yield results until connection closes while not connection_closed.is_set(): try: chunk = transcript_queue.get(timeout=0.1) if isinstance(chunk, Exception): raise chunk yield chunk except queue.Empty: continue # Get any remaining results while not transcript_queue.empty(): try: chunk = transcript_queue.get_nowait() if isinstance(chunk, Exception): raise chunk yield chunk except queue.Empty: break except Exception as e: raise ASRError(f"Deepgram streaming transcription error: {e}") def _prepare_audio_payload(self, file: Union[str, BinaryIO]) -> bytes: """Prepare audio payload for Deepgram API v5. Returns raw bytes instead of dict payload (v5 API change). 
""" if isinstance(file, str): with open(file, "rb") as audio_file: buffer_data = audio_file.read() else: if hasattr(file, "read"): buffer_data = file.read() else: raise ValueError( "File must be a file path string or file-like object" ) return buffer_data async def _load_and_prepare_audio( self, file: Union[str, BinaryIO] ) -> tuple[np.ndarray, int]: """Load and prepare audio file for streaming. Conversions performed only when necessary: - Stereo to mono: Required for multi-channel audio - Sample rate conversion: Required when input != 16kHz - Other formats: Error out as unsupported """ try: try: import soundfile as sf except ImportError: raise ASRError( "soundfile is required for audio processing. Install with: pip install soundfile" ) if isinstance(file, str): audio_data, original_sample_rate = sf.read(file) else: audio_data, original_sample_rate = sf.read(file) audio_data = np.asarray(audio_data, dtype=np.float32) # Convert to mono if stereo if len(audio_data.shape) > 1: if audio_data.shape[1] == 2: audio_data = np.mean(audio_data, axis=1) else: raise ASRError( f"Unsupported audio format: {audio_data.shape[1]} channels. Only mono and stereo are supported." ) # Resample to 16kHz if needed target_sample_rate = 16000 if original_sample_rate != target_sample_rate: try: from scipy import signal num_samples = int( len(audio_data) * target_sample_rate / original_sample_rate ) audio_data = signal.resample(audio_data, num_samples) except ImportError: raise ASRError( f"Audio resampling required but scipy not available. " f"Input is {original_sample_rate}Hz, need {target_sample_rate}Hz. " f"Install scipy or provide audio at {target_sample_rate}Hz." 
) return np.asarray(audio_data, dtype=np.float32), target_sample_rate except Exception as e: if isinstance(e, ASRError): raise raise ASRError(f"Error loading audio file: {e}") def _send_audio_chunk(self, connection, audio_chunk: np.ndarray) -> None: """Send audio chunk data through the connection.""" streaming_chunk_size = 8000 # Match reference BLOCKSIZE (~0.5s @16kHz mono) send_delay = 0.01 for i in range(0, len(audio_chunk), streaming_chunk_size): piece = audio_chunk[i : i + streaming_chunk_size] if len(piece) < streaming_chunk_size: piece = np.pad( piece, (0, streaming_chunk_size - len(piece)), mode="constant" ) pcm16 = (piece * 32767).astype(np.int16).tobytes() connection.send(pcm16) time.sleep(send_delay) # Use synchronous sleep like reference def _parse_deepgram_response(self, response_dict: dict) -> TranscriptionResult: """Convert Deepgram API response to unified TranscriptionResult.""" try: results = response_dict.get("results", {}) channels = results.get("channels", []) if not channels or not channels[0].get("alternatives"): return TranscriptionResult( text="", language=None, confidence=None, task="transcribe" ) best_alternative = channels[0]["alternatives"][0] text = best_alternative.get("transcript", "") confidence = best_alternative.get("confidence", None) words = [ Word( word=word_data.get("word", ""), start=word_data.get("start", None), end=word_data.get("end", None), confidence=word_data.get("confidence", None), ) for word_data in best_alternative.get("words", []) ] segments = [] paragraphs = results.get("paragraphs", {}).get("paragraphs", []) for para in paragraphs: for sentence in para.get("sentences", []): segments.append( Segment( id=len(segments), seek=0, start=sentence.get("start", None), end=sentence.get("end", None), text=sentence.get("text", ""), tokens=[], temperature=0.0, avg_logprob=0.0, compression_ratio=0.0, no_speech_prob=0.0, ) ) alternatives_list = [ Alternative( transcript=alt.get("transcript", ""), confidence=alt.get("confidence", 
class DeepseekProvider(Provider):
    """Provider for Deepseek."""

    def __init__(self, **config):
        """
        Initialize the DeepSeek provider with the given configuration.
        Pass the entire configuration dictionary to the OpenAI client constructor.

        Raises:
            ValueError: if no API key is given in ``config`` and the
                ``DEEPSEEK_API_KEY`` environment variable is unset or empty.
        """
        # Fall back to the environment when the key is missing or explicitly
        # None/empty in the config.
        config["api_key"] = config.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        if not config["api_key"]:
            # The message names the variable that is actually read above.
            raise ValueError(
                "DeepSeek API key is missing. Please provide it in the config or "
                "set the DEEPSEEK_API_KEY environment variable."
            )
        # The OpenAI client is pointed at DeepSeek's endpoint.
        config["base_url"] = "https://api.deepseek.com"

        # Pass the entire config to the OpenAI client constructor
        self.client = openai.OpenAI(**config)
        # DeepSeek's response format is the same as OpenAI's, so the
        # OpenAI-compliant converter is reused.
        self.transformer = OpenAICompliantMessageConverter()

    def chat_completions_create(self, model, messages, **kwargs):
        """Create a chat completion via the OpenAI-compatible DeepSeek endpoint.

        Raises:
            LLMError: wrapping any exception raised by the client call or by
                the response conversion.
        """
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs,  # Pass any additional arguments to the OpenAI API
            )
            return self.transformer.convert_response(response.model_dump())
        except Exception as e:
            raise LLMError(f"An error occurred: {e}") from e
completion_response = ChatCompletionResponse() choice = resp_json["choices"][0] message = choice["message"] # Set basic message content completion_response.choices[0].message.content = message.get("content") completion_response.choices[0].message.role = message.get("role", "assistant") # Handle tool calls if present if "tool_calls" in message and message["tool_calls"] is not None: tool_calls = [] for tool_call in message["tool_calls"]: new_tool_call = ChatCompletionMessageToolCall( id=tool_call["id"], type=tool_call["type"], function={ "name": tool_call["function"]["name"], "arguments": tool_call["function"]["arguments"], }, ) tool_calls.append(new_tool_call) completion_response.choices[0].message.tool_calls = tool_calls return completion_response # Models that support tool calls: # [As of 01/20/2025 from https://docs.fireworks.ai/guides/function-calling] # Llama 3.1 405B Instruct # Llama 3.1 70B Instruct # Qwen 2.5 72B Instruct # Mixtral MoE 8x22B Instruct # Firefunction-v2: Latest and most performant model, optimized for complex function calling scenarios (on-demand only) # Firefunction-v1: Previous generation, Mixtral-based function calling model optimized for fast routing and structured output (on-demand only) class FireworksProvider(Provider): """ Fireworks AI Provider using httpx for direct API calls. """ BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions" def __init__(self, **config): """ Initialize the Fireworks provider with the given configuration. The API key is fetched from the config or environment variables. """ self.api_key = config.get("api_key", os.getenv("FIREWORKS_API_KEY")) if not self.api_key: raise ValueError( "Fireworks API key is missing. Please provide it in the config or set the FIREWORKS_API_KEY environment variable." 
def chat_completions_create(self, model, messages, **kwargs):
    """
    Makes a request to the Fireworks AI chat completions endpoint using httpx.

    ``stream`` is dropped from ``kwargs``; ``tools``, ``tool_choice`` and all
    remaining keyword arguments are added to the JSON payload.

    Raises:
        LLMError: on an HTTP error status (message holds status code, headers
            and response body) or on any other failure.
    """
    # Remove 'stream' from kwargs if present
    kwargs.pop("stream", None)

    # Build the payload: model and messages first, then tool settings, then
    # whatever else the caller supplied.
    data = {
        "model": model,
        "messages": self.transformer.convert_request(messages),
    }
    for key in ("tools", "tool_choice"):
        if key in kwargs:
            data[key] = kwargs.pop(key)
    data.update(kwargs)

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }

    try:
        response = httpx.post(
            self.BASE_URL, json=data, headers=headers, timeout=self.timeout
        )
        response.raise_for_status()
        return self.transformer.convert_response(response.json())
    except httpx.HTTPStatusError as error:
        # Status code and headers live on error.response; reading them from the
        # exception itself raised AttributeError and hid the real failure.
        error_message = (
            f"The request failed with status code: {error.response.status_code}\n"
        )
        error_message += f"Headers: {error.response.headers}\n"
        error_message += error.response.text
        raise LLMError(error_message) from error
    except Exception as e:
        raise LLMError(f"An error occurred: {e}") from e
""" normalized_response = ChatCompletionResponse() normalized_response.choices[0].message.content = response_data["choices"][0][ "message" ]["content"] return normalized_response ================================================ FILE: aisuite/providers/google_provider.py ================================================ """The interface to Google's Vertex AI.""" import os import json from typing import List, Dict, Any, Optional, Union, BinaryIO, AsyncGenerator import vertexai from vertexai.generative_models import ( GenerativeModel, GenerationConfig, Content, Part, Tool, FunctionDeclaration, ) import pprint from aisuite.framework import ChatCompletionResponse, Message from aisuite.framework.message import ( TranscriptionResult, Word, Segment, Alternative, StreamingTranscriptionChunk, ) from aisuite.provider import Provider, ASRError, Audio DEFAULT_TEMPERATURE = 0.7 ENABLE_DEBUG_MESSAGES = False # Links. # https://codelabs.developers.google.com/codelabs/gemini-function-calling#6 # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#chat-samples class GoogleMessageConverter: @staticmethod def convert_user_role_message(message: Dict[str, Any]) -> Content: """Convert user or system messages to Google Vertex AI format.""" parts = [Part.from_text(message["content"])] return Content(role="user", parts=parts) @staticmethod def convert_assistant_role_message(message: Dict[str, Any]) -> Content: """Convert assistant messages to Google Vertex AI format.""" if "tool_calls" in message and message["tool_calls"]: # Handle function calls tool_call = message["tool_calls"][ 0 ] # Assuming single function call for now function_call = tool_call["function"] # Create a Part from the function call parts = [ Part.from_dict( { "function_call": { "name": function_call["name"], # "arguments": json.loads(function_call["arguments"]) } } ) ] # return Content(role="function", parts=parts) else: # Handle regular text messages parts = 
@staticmethod
def convert_response(response) -> ChatCompletionResponse:
    """Normalize the response from Vertex AI to match OpenAI's response format.

    Only the first candidate is read. All of its parts are inspected, because
    a function call is not necessarily the first part. This is a valid
    response:

        candidates {
          content {
            role: "model"
            parts { text: "The current temperature in San Francisco is ..." }
            parts { function_call { name: "is_it_raining" args { ... } } }
          }
          finish_reason: STOP
        }

    Args:
        response: Vertex AI response exposing ``candidates[0].content.parts``.

    Returns:
        A ChatCompletionResponse whose first choice either carries
        ``tool_calls`` (finish_reason "tool_calls", with any accompanying
        text as ``content``) or the text of the first part (finish_reason
        "stop").
    """
    openai_response = ChatCompletionResponse()

    if ENABLE_DEBUG_MESSAGES:
        print("Dumping the response")
        pprint.pprint(response)

    parts = response.candidates[0].content.parts
    # parts[0] is read by index, as before, so an empty part list still
    # raises IndexError; the remaining parts are scanned in addition.
    all_parts = [parts[0], *list(parts)[1:]]

    # Note: checking that the function_call attribute exists is not enough,
    # it must also be set (truthy).
    function_calls = [
        part.function_call
        for part in all_parts
        if getattr(part, "function_call", None)
    ]

    if not function_calls:
        # Handle regular text response
        openai_response.choices[0].message.content = parts[0].text
        openai_response.choices[0].finish_reason = "stop"
        return openai_response

    tool_calls = []
    for index, function_call in enumerate(function_calls):
        # args is a MapComposite; copy it into a plain dict for json.dumps
        args_dict = dict(function_call.args.items())
        if ENABLE_DEBUG_MESSAGES:
            print("Dumping the args_dict")
            pprint.pprint(args_dict)

        # The first call keeps the original ID format; later calls get an
        # index suffix so two calls to the same function do not share an ID.
        call_id = f"call_{hash(function_call.name)}"
        if index:
            call_id = f"{call_id}_{index}"

        tool_calls.append(
            {
                "type": "function",
                "id": call_id,
                "function": {
                    "name": function_call.name,
                    "arguments": json.dumps(args_dict),
                },
            }
        )

    # Keep any text the model produced next to the function call(s)
    text_pieces = []
    for part in all_parts:
        if getattr(part, "function_call", None):
            continue
        piece = getattr(part, "text", None)
        if isinstance(piece, str) and piece:
            text_pieces.append(piece)

    openai_response.choices[0].message = Message(
        role="assistant",
        content="".join(text_pieces) or None,
        tool_calls=tool_calls,
        refusal=None,
    )
    openai_response.choices[0].finish_reason = "tool_calls"
    return openai_response
os.getenv("GOOGLE_REGION") self.app_creds_path = config.get("application_credentials") or os.getenv( "GOOGLE_APPLICATION_CREDENTIALS" ) if not self.project_id or not self.location or not self.app_creds_path: raise EnvironmentError( "Missing one or more required Google environment variables: " "GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. " "Please refer to the setup guide: /guides/google.md." ) vertexai.init(project=self.project_id, location=self.location) self.transformer = GoogleMessageConverter() # Initialize Speech client lazily self._speech_client = None # Initialize audio functionality self.audio = GoogleAudio(self) def chat_completions_create(self, model, messages, **kwargs): """Request chat completions from the Google AI API. Args: ---- model (str): Identifies the specific provider/model to use. messages (list of dict): A list of message objects in chat history. kwargs (dict): Optional arguments for the Google AI API. Returns: ------- The ChatCompletionResponse with the completion result. 
""" # Set the temperature if provided, otherwise use the default temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE) # Convert messages to Vertex AI format message_history = self.transformer.convert_request(messages) # Handle tools if provided tools = None if "tools" in kwargs: tools = [ Tool( function_declarations=[ FunctionDeclaration( name=tool["function"]["name"], description=tool["function"].get("description", ""), parameters={ "type": "object", "properties": { param_name: { "type": param_info.get("type", "string"), "description": param_info.get( "description", "" ), **( {"enum": param_info["enum"]} if "enum" in param_info else {} ), } for param_name, param_info in tool["function"][ "parameters" ]["properties"].items() }, "required": tool["function"]["parameters"].get( "required", [] ), }, ) for tool in kwargs["tools"] ] ) ] # Create the GenerativeModel model = GenerativeModel( model, generation_config=GenerationConfig(temperature=temperature), tools=tools, ) if ENABLE_DEBUG_MESSAGES: print("Dumping the message_history") pprint.pprint(message_history) # Start chat and get response chat = model.start_chat(history=message_history[:-1]) last_message = message_history[-1] # If the last message is a function response, send the Part object directly # Otherwise, send just the text content message_to_send = ( Content(role="function", parts=[last_message]) if isinstance(last_message, Part) else last_message.parts[0].text ) # response = chat.send_message(message_to_send) response = chat.send_message(message_to_send) # Convert and return the response return self.transformer.convert_response(response) @property def speech_client(self): """Lazy initialization of Google Cloud Speech client.""" if self._speech_client is None: try: from google.cloud import speech self._speech_client = speech.SpeechClient() except ImportError: raise ImportError( "google-cloud-speech is required for ASR functionality. 
" "Install it with: pip install google-cloud-speech" ) return self._speech_client # Audio Classes class GoogleAudio(Audio): """Google Audio functionality container.""" def __init__(self, provider): super().__init__() self.provider = provider self.transcriptions = self.Transcriptions(provider) class Transcriptions(Audio.Transcription): """Google Audio Transcriptions functionality.""" def __init__(self, provider): self.provider = provider def create( self, model: str, file: Union[str, BinaryIO], **kwargs, ) -> TranscriptionResult: """ Create audio transcription using Google Cloud Speech-to-Text API. All parameters are already validated and mapped by the Client layer. This is a simple pass-through to the Google API. """ try: from google.cloud import speech # Set defaults kwargs["model"] = model if model != "default" else "latest_long" kwargs.setdefault("sample_rate_hertz", 16000) kwargs.setdefault("enable_automatic_punctuation", True) audio_data = self._read_audio_data(file) audio = speech.RecognitionAudio(content=audio_data) config = self._build_recognition_config(kwargs, speech, file) response = self.provider.speech_client.recognize( config=config, audio=audio ) return self._parse_google_response(response) except ImportError: raise ASRError( "google-cloud-speech is required for ASR functionality. " "Install it with: pip install google-cloud-speech" ) except Exception as e: raise ASRError(f"Google Speech-to-Text error: {e}") from e async def create_stream_output( self, model: str, file: Union[str, BinaryIO], **kwargs, ) -> AsyncGenerator[StreamingTranscriptionChunk, None]: """ Create streaming audio transcription using Google Cloud Speech-to-Text API. All parameters are already validated and mapped by the Client layer. This implementation handles streaming with Google's API. 
""" try: from google.cloud import speech # Set defaults kwargs["model"] = model if model != "default" else "latest_long" kwargs.setdefault("sample_rate_hertz", 16000) kwargs.setdefault("enable_automatic_punctuation", True) config = self._build_recognition_config(kwargs, speech, file) streaming_config = speech.StreamingRecognitionConfig( config=config, interim_results=True, single_utterance=False ) audio_data = self._read_audio_data(file) request_generator = self._create_streaming_requests( speech, streaming_config, audio_data ) responses = self.provider.speech_client.streaming_recognize( config=streaming_config, requests=request_generator ) for response in responses: for result in response.results: if result.alternatives: alternative = result.alternatives[0] yield StreamingTranscriptionChunk( text=alternative.transcript, is_final=result.is_final, confidence=getattr(alternative, "confidence", None), ) except ImportError: raise ASRError( "google-cloud-speech is required for ASR functionality. 
" "Install it with: pip install google-cloud-speech" ) except Exception as e: raise ASRError(f"Google Speech-to-Text streaming error: {e}") from e def _read_audio_data(self, file: Union[str, BinaryIO]) -> bytes: """Read audio data from file or file-like object.""" if isinstance(file, str): with open(file, "rb") as audio_file: return audio_file.read() else: return file.read() def _detect_audio_encoding(self, file: Union[str, BinaryIO], speech): """Detect audio encoding based on file extension or content.""" if isinstance(file, str): # File path - detect by extension file_lower = file.lower() if file_lower.endswith(".mp3"): return speech.RecognitionConfig.AudioEncoding.MP3 elif file_lower.endswith(".flac"): return speech.RecognitionConfig.AudioEncoding.FLAC elif file_lower.endswith(".wav"): return speech.RecognitionConfig.AudioEncoding.LINEAR16 elif file_lower.endswith(".ogg"): return speech.RecognitionConfig.AudioEncoding.OGG_OPUS elif file_lower.endswith(".webm"): return speech.RecognitionConfig.AudioEncoding.WEBM_OPUS # Default to LINEAR16 for unknown formats return speech.RecognitionConfig.AudioEncoding.LINEAR16 def _build_recognition_config( self, params: dict, speech, file: Union[str, BinaryIO] ): """Build Google Speech RecognitionConfig from parameters.""" # Auto-detect encoding if not specified encoding = params.get("encoding") if encoding is None: encoding = self._detect_audio_encoding(file, speech) config_params = { "encoding": encoding, "sample_rate_hertz": params.get("sample_rate_hertz", 16000), "language_code": params.get("language_code", "en-US"), "enable_word_time_offsets": True, "enable_word_confidence": True, "enable_automatic_punctuation": params.get( "enable_automatic_punctuation", True ), "model": params["model"], } for param in ["max_alternatives", "profanity_filter", "speech_contexts"]: if param in params: config_params[param] = params[param] return speech.RecognitionConfig(**config_params) def _create_streaming_requests( self, speech, 
def _parse_google_response(self, response) -> TranscriptionResult:
    """Convert Google Speech-to-Text response to unified TranscriptionResult.

    The top alternative of every result that has alternatives is combined,
    so the transcript is not cut off after the first result. For a response
    with a single result the output is the same as before.

    Args:
        response: Object exposing ``results``; each result exposes
            ``alternatives`` with ``transcript`` and optionally
            ``confidence`` and ``words``.

    Returns:
        TranscriptionResult with the joined text, the mean of the reported
        confidences, all words, and one Segment per result that has words.
        ``alternatives`` lists the alternatives of the first result only.
        An empty result (text "") is returned when nothing was recognized.
    """
    results = [result for result in (response.results or []) if result.alternatives]
    if not results:
        return TranscriptionResult(
            text="", language=None, confidence=None, task="transcribe"
        )

    top_alternatives = [result.alternatives[0] for result in results]

    if len(top_alternatives) == 1:
        # Single result: keep the transcript exactly as returned
        text = top_alternatives[0].transcript
    else:
        stripped = (alt.transcript.strip() for alt in top_alternatives)
        text = " ".join(piece for piece in stripped if piece)

    confidences = [
        value
        for value in (getattr(alt, "confidence", None) for alt in top_alternatives)
        if value is not None
    ]
    confidence = sum(confidences) / len(confidences) if confidences else None

    words = []
    segments = []
    for alt in top_alternatives:
        alt_words = [
            Word(
                word=word.word,
                start=(
                    word.start_time.total_seconds()
                    if hasattr(word, "start_time")
                    else 0.0
                ),
                end=(
                    word.end_time.total_seconds()
                    if hasattr(word, "end_time")
                    else 0.0
                ),
                confidence=getattr(word, "confidence", None),
            )
            for word in (getattr(alt, "words", None) or [])
        ]
        if not alt_words:
            continue
        words.extend(alt_words)
        # One segment per result, spanning its first to last word; the
        # remaining Segment fields have no source here and stay at zero.
        segments.append(
            Segment(
                id=len(segments),
                seek=0,
                start=alt_words[0].start,
                end=alt_words[-1].end,
                text=alt.transcript,
                tokens=[],
                temperature=0.0,
                avg_logprob=0.0,
                compression_ratio=0.0,
                no_speech_prob=0.0,
            )
        )

    alternatives = [
        Alternative(
            transcript=alt.transcript,
            confidence=getattr(alt, "confidence", None),
        )
        for alt in results[0].alternatives
    ]

    return TranscriptionResult(
        text=text,
        language=None,
        confidence=confidence,
        task="transcribe",
        words=words or None,
        alternatives=alternatives or None,
        segments=segments or None,
    )
class GroqProvider(Provider):
    """Chat-completions provider backed by the official Groq client."""

    def __init__(self, **config):
        """
        Initialize the Groq provider with the given configuration.
        Pass the entire configuration dictionary to the Groq client constructor.

        Raises:
            ValueError: If no API key is found in ``config`` or in the
                GROQ_API_KEY environment variable.
        """
        # Run the base initializer, as GoogleProvider and HuggingfaceProvider do
        super().__init__()

        # Ensure API key is provided either in config or via environment
        # variable; an explicit empty/None key falls back to the environment.
        self.api_key = config.get("api_key") or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key is missing. Please provide it in the config or set the GROQ_API_KEY environment variable."
            )
        config["api_key"] = self.api_key
        self.client = groq.Groq(**config)
        self.transformer = GroqMessageConverter()

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the Groq chat completions endpoint using the official client.

        Raises:
            LLMError: Wraps any exception raised while converting messages,
                calling the API or converting the response.
        """
        try:
            # Transform messages using converter
            transformed_messages = self.transformer.convert_request(messages)
            response = self.client.chat.completions.create(
                model=model,
                messages=transformed_messages,
                **kwargs,  # Pass any additional arguments to the Groq API
            )
            return self.transformer.convert_response(response.model_dump())
        except Exception as e:
            # Chain the cause so the original traceback is preserved
            raise LLMError(f"An error occurred: {e}") from e
def chat_completions_create(self, model, messages, **kwargs):
    """
    Makes a request to the Inference API endpoint using InferenceClient.

    Args:
        model: Model identifier forwarded to ``self.client.chat_completion``.
        messages: Iterable of framework ``Message`` objects and/or dicts.
            Dict messages are copied, so the caller's objects are not modified.
        **kwargs: Other parameters like temperature, max_tokens, etc.

    Returns:
        The response normalized by ``self._normalize_response``.

    Raises:
        ValueError: If a message is neither a ``Message`` nor a dict.
        LLMError: Wraps any exception raised by the API call or normalization.
    """
    # Validate and transform messages
    transformed_messages = []
    for message in messages:
        if isinstance(message, Message):
            transformed_message = self.transform_from_message(message)
        elif isinstance(message, dict):
            # Shallow copy: `content` may be rewritten below
            transformed_message = dict(message)
        else:
            raise ValueError(f"Invalid message format: {message}")

        # Ensure 'content' is present and not None
        if transformed_message.get("content") is None:
            transformed_message["content"] = ""

        transformed_messages.append(transformed_message)

    try:
        # Prepare the payload
        payload = {
            "messages": transformed_messages,
            **kwargs,  # Include other parameters like temperature, max_tokens, etc.
        }

        # Make the API call using the client
        response = self.client.chat_completion(model=model, **payload)
        return self._normalize_response(response)
    except Exception as e:
        raise LLMError(f"An error occurred: {e}") from e
""" normalized_response = ChatCompletionResponse() message_data = response_data["choices"][0]["message"] normalized_response.choices[0].message = self.transform_to_message(message_data) return normalized_response # Audio Classes class HuggingfaceAudio(Audio): """Hugging Face Audio functionality container.""" def __init__(self, token, timeout=120): super().__init__() self.transcriptions = self.Transcriptions(token, timeout) class Transcriptions(Audio.Transcription): """Hugging Face Audio Transcriptions functionality.""" def __init__(self, token, timeout=120): self.token = token self.timeout = timeout def create( self, model: str, file: Union[str, BinaryIO], **kwargs, ) -> TranscriptionResult: """ Create audio transcription using Hugging Face Inference API. All parameters are already validated and mapped by the Client layer. This makes an HTTP POST request to the Hugging Face Inference API. Note: Whisper-based models have a 30-second processing window. For longer audio, users should deploy custom Inference Endpoints. 
""" try: # Extract model ID from format "huggingface:model-id" model_id = model.split(":", 1)[1] if ":" in model else model # Prepare API endpoint url = f"https://api-inference.huggingface.co/models/{model_id}" # Prepare audio data if isinstance(file, str): with open(file, "rb") as audio_file: audio_bytes = audio_file.read() content_type = self._detect_content_type(file) else: audio_bytes = file.read() # Default to wav for file-like objects content_type = "audio/wav" # Prepare headers headers = { "Authorization": f"Bearer {self.token}", "Content-Type": content_type, } # First attempt without wait_for_model try: response = requests.post( url, headers=headers, data=audio_bytes, timeout=self.timeout, ) response.raise_for_status() except requests.exceptions.HTTPError as e: # If 503 (model loading), retry with x-wait-for-model header if e.response.status_code == 503: headers["x-wait-for-model"] = "true" response = requests.post( url, headers=headers, data=audio_bytes, timeout=self.timeout, ) response.raise_for_status() else: raise # Parse response response_data = response.json() return self._parse_huggingface_response(response_data, model_id) except requests.exceptions.RequestException as e: raise ASRError(f"Hugging Face transcription error: {e}") from e except Exception as e: raise ASRError(f"Hugging Face transcription error: {e}") from e def _detect_content_type(self, file_path: str) -> str: """Detect audio content type from file extension.""" if file_path.lower().endswith(".wav"): return "audio/wav" elif file_path.lower().endswith(".mp3"): return "audio/mpeg" # HF API requires audio/mpeg for MP3 elif file_path.lower().endswith(".flac"): return "audio/flac" else: # Default to wav if unknown return "audio/wav" def _parse_huggingface_response( self, response_data, model_id: str ) -> TranscriptionResult: """ Parse Hugging Face API response into TranscriptionResult. 
Response format can vary: - Standard: {"text": "...", "chunks": [...]} - Text only: {"text": "..."} - Some models may use different keys """ try: # Extract text if isinstance(response_data, dict): text = response_data.get("text", "") elif isinstance(response_data, str): # Some models return plain string text = response_data else: text = str(response_data) # Extract words from chunks if available words = None if isinstance(response_data, dict) and "chunks" in response_data: chunks = response_data["chunks"] if chunks: words = [] for chunk in chunks: if isinstance(chunk, dict): word_text = chunk.get("text", "") timestamp = chunk.get("timestamp") # timestamp can be [start, end] or (start, end) start, end = None, None if timestamp and len(timestamp) >= 2: start, end = timestamp[0], timestamp[1] words.append( Word( word=word_text, start=start, end=end, confidence=None, # HF doesn't provide confidence ) ) return TranscriptionResult( text=text, language=None, # HF API doesn't return language confidence=None, # HF API doesn't return confidence words=words, task="transcribe", ) except (KeyError, TypeError, IndexError) as e: raise ASRError(f"Error parsing Hugging Face response: {e}") ================================================ FILE: aisuite/providers/inception_provider.py ================================================ import openai import os from aisuite.provider import Provider, LLMError class InceptionProvider(Provider): def __init__(self, **config): """ Initialize the Inception provider with the given configuration. Pass the entire configuration dictionary to the Inception client constructor using openai. """ # Ensure API key is provided either in config or via environment variable config.setdefault("api_key", os.getenv("INCEPTION_API_KEY")) if not config["api_key"]: raise ValueError( "Inception API key is missing. Please provide it in the config or set the INCEPTION_API_KEY environment variable." 
class LmstudioProvider(Provider):
    """
    LM Studio Provider that makes HTTP calls. Inspired by OllamaProvider in aisuite.
    It uses the /v1/chat/completions endpoint.
    Read more here - https://lmstudio.ai/docs/api and on your local instance in the "Developer" tab.
    If LMSTUDIO_API_URL is not set and not passed in config, then it will default to "http://localhost:1234"
    """

    _CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions"
    # The previous text told users to run `ollama serve`, which starts
    # Ollama, not LM Studio.
    _CONNECT_ERROR_MESSAGE = (
        "LM Studio is likely not running. Start the LM Studio local server from "
        "the Developer tab or by running `lms server start` on your host."
    )

    def __init__(self, **config):
        """
        Initialize the LM Studio provider with the given configuration.

        Reads ``api_url`` (falling back to LMSTUDIO_API_URL, then
        "http://localhost:1234") and ``timeout`` (default 300 seconds).
        """
        # Run the base initializer, as GoogleProvider and HuggingfaceProvider do
        super().__init__()
        self.url = config.get("api_url") or os.getenv(
            "LMSTUDIO_API_URL", "http://localhost:1234"
        )

        # Optionally set a custom timeout (default to 300s)
        self.timeout = config.get("timeout", 300)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the chat completions endpoint using httpx.

        ``stream`` is always forced to False.

        Raises:
            LLMError: On connection failure, an error status, or any other
                failure while sending the request.
        """
        kwargs["stream"] = False
        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API
        }

        try:
            response = httpx.post(
                self.url.rstrip("/") + self._CHAT_COMPLETION_ENDPOINT,
                json=data,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.ConnectError as connect_err:  # Handle connection errors
            raise LLMError(
                f"Connection failed: {self._CONNECT_ERROR_MESSAGE}"
            ) from connect_err
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"LM Studio request failed: {http_err}") from http_err
        except Exception as e:
            raise LLMError(f"An error occurred: {e}") from e

        # Return the normalized response
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the API response to a common format (ChatCompletionResponse).

        Only the first choice's message content is copied.
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["choices"][0][
            "message"
        ]["content"]
        return normalized_response
""" # Class variable that derived classes can override tool_results_as_strings = False @staticmethod def convert_request(messages): """Convert messages to OpenAI-compatible format.""" transformed_messages = [] for message in messages: tmsg = None if isinstance(message, Message): message_dict = message.model_dump(mode="json") message_dict.pop("refusal", None) # Remove refusal field if present tmsg = message_dict else: tmsg = message # Check if tmsg is a dict, otherwise get role attribute role = tmsg["role"] if isinstance(tmsg, dict) else tmsg.role if role == "tool": if OpenAICompliantMessageConverter.tool_results_as_strings: # Handle both dict and object cases for content if isinstance(tmsg, dict): tmsg["content"] = str(tmsg["content"]) else: tmsg.content = str(tmsg.content) transformed_messages.append(tmsg) return transformed_messages def convert_response(self, response_data) -> ChatCompletionResponse: """Normalize the response to match OpenAI's response format.""" completion_response = ChatCompletionResponse() choice = response_data["choices"][0] message = choice["message"] # Set basic message content completion_response.choices[0].message.content = message["content"] completion_response.choices[0].message.role = message.get("role", "assistant") # Conditionally parse usage data if it exists. 
def get_completion_usage(self, usage_data: dict):
    """Build a CompletionUsage from a usage dictionary; absent keys are passed as None."""
    usage_fields = (
        "completion_tokens",
        "prompt_tokens",
        "total_tokens",
        "prompt_tokens_details",
        "completion_tokens_details",
    )
    field_values = {name: usage_data.get(name) for name in usage_fields}
    return CompletionUsage(**field_values)
class MistralMessageConverter(OpenAICompliantMessageConverter):
    """
    Message converter for Mistral responses.
    """

    def convert_response(self, response_data) -> ChatCompletionResponse:
        """Dump Mistral's response object to a dict and delegate to the base converter."""
        return super().convert_response(response_data.model_dump())
""" try: # Transform messages using converter transformed_messages = self.transformer.convert_request(messages) # Make the request to Mistral response = self.client.chat.complete( model=model, messages=transformed_messages, **kwargs ) return self.transformer.convert_response(response) except Exception as e: raise LLMError(f"An error occurred: {e}") from e ================================================ FILE: aisuite/providers/nebius_provider.py ================================================ import os from aisuite.provider import Provider from openai import Client BASE_URL = "https://api.studio.nebius.ai/v1" # TODO(rohitcp): This needs to be added to our internal testbed. Tool calling not tested. class NebiusProvider(Provider): def __init__(self, **config): """ Initialize the Nebius AI Studio provider with the given configuration. Pass the entire configuration dictionary to the OpenAI client constructor. """ # Ensure API key is provided either in config or via environment variable config.setdefault("api_key", os.getenv("NEBIUS_API_KEY")) if not config["api_key"]: raise ValueError( "Nebius AI Studio API key is missing. Please provide it in the config or set the NEBIUS_API_KEY environment variable. You can get your API key at https://studio.nebius.ai/settings/api-keys" ) config["base_url"] = BASE_URL # Pass the entire config to the OpenAI client constructor self.client = Client(**config) def chat_completions_create(self, model, messages, **kwargs): return self.client.chat.completions.create( model=model, messages=messages, **kwargs # Pass any additional arguments to the Nebius API ) ================================================ FILE: aisuite/providers/ollama_provider.py ================================================ import os import httpx from aisuite.provider import Provider, LLMError from aisuite.framework import ChatCompletionResponse class OllamaProvider(Provider): """ Ollama Provider that makes HTTP calls instead of using SDK. 
It uses the /api/chat endpoint. Read more here - https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion If OLLAMA_API_URL is not set and not passed in config, then it will default to "http://localhost:11434" """ _CHAT_COMPLETION_ENDPOINT = "/api/chat" _CONNECT_ERROR_MESSAGE = "Ollama is likely not running. Start Ollama by running `ollama serve` on your host." def __init__(self, **config): """ Initialize the Ollama provider with the given configuration. """ self.url = config.get("api_url") or os.getenv( "OLLAMA_API_URL", "http://localhost:11434" ) # Optionally set a custom timeout (default to 30s) self.timeout = config.get("timeout", 30) def chat_completions_create(self, model, messages, **kwargs): """ Makes a request to the chat completions endpoint using httpx. """ kwargs["stream"] = False data = { "model": model, "messages": messages, **kwargs, # Pass any additional arguments to the API } try: response = httpx.post( self.url.rstrip("/") + self._CHAT_COMPLETION_ENDPOINT, json=data, timeout=self.timeout, ) response.raise_for_status() except httpx.ConnectError: # Handle connection errors raise LLMError(f"Connection failed: {self._CONNECT_ERROR_MESSAGE}") except httpx.HTTPStatusError as http_err: raise LLMError(f"Ollama request failed: {http_err}") except Exception as e: raise LLMError(f"An error occurred: {e}") # Return the normalized response return self._normalize_response(response.json()) def _normalize_response(self, response_data): """ Normalize the API response to a common format (ChatCompletionResponse). 
""" normalized_response = ChatCompletionResponse() normalized_response.choices[0].message.content = response_data["message"][ "content" ] return normalized_response ================================================ FILE: aisuite/providers/openai_provider.py ================================================ import openai import os from typing import Union, BinaryIO, AsyncGenerator from aisuite.provider import Provider, LLMError, ASRError, Audio from aisuite.providers.message_converter import OpenAICompliantMessageConverter from aisuite.framework.message import ( TranscriptionResult, Segment, Word, StreamingTranscriptionChunk, ) class OpenaiProvider(Provider): def __init__(self, **config): """ Initialize the OpenAI provider with the given configuration. Pass the entire configuration dictionary to the OpenAI client constructor. """ # Ensure API key is provided either in config or via environment variable config.setdefault("api_key", os.getenv("OPENAI_API_KEY")) if not config["api_key"]: raise ValueError( "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." ) # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically # infer certain values from the environment variables. # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc. # Pass the entire config to the OpenAI client constructor self.client = openai.OpenAI(**config) self.transformer = OpenAICompliantMessageConverter() # Initialize audio functionality super().__init__() self.audio = OpenAIAudio(self.client) def chat_completions_create(self, model, messages, **kwargs): # Any exception raised by OpenAI will be returned to the caller. # Maybe we should catch them and raise a custom LLMError. 
try: transformed_messages = self.transformer.convert_request(messages) response = self.client.chat.completions.create( model=model, messages=transformed_messages, **kwargs, # Pass any additional arguments to the OpenAI API ) return response except Exception as e: raise LLMError(f"An error occurred: {e}") # Audio Classes class OpenAIAudio(Audio): """OpenAI Audio functionality container.""" def __init__(self, client): super().__init__() self.transcriptions = self.Transcriptions(client) class Transcriptions(Audio.Transcription): """OpenAI Audio Transcriptions functionality.""" def __init__(self, client): self.client = client def create( self, model: str, file: Union[str, BinaryIO], **kwargs, ) -> TranscriptionResult: """ Create audio transcription using OpenAI Whisper API. All parameters are already validated and mapped by the Client layer. This is a simple pass-through to the OpenAI API. """ try: # Handle TranscriptionOptions object if passed if "options" in kwargs: options = kwargs.pop("options") # Extract all non-None attributes from options object if hasattr(options, "__dict__"): for key, value in options.__dict__.items(): if value is not None and key not in kwargs: kwargs[key] = value # Handle timestamp_granularities requirement if "timestamp_granularities" in kwargs: # OpenAI requires verbose_json format for timestamp_granularities kwargs["response_format"] = "verbose_json" # Handle file input if isinstance(file, str): with open(file, "rb") as audio_file: response = self.client.audio.transcriptions.create( file=audio_file, model=model, **kwargs ) else: response = self.client.audio.transcriptions.create( file=file, model=model, **kwargs ) return self._parse_openai_response(response) except Exception as e: raise ASRError(f"OpenAI transcription error: {e}") from e async def create_stream_output( self, model: str, file: Union[str, BinaryIO], **kwargs, ) -> AsyncGenerator[StreamingTranscriptionChunk, None]: """ Create streaming audio transcription using OpenAI 
Whisper API. All parameters are already validated and mapped by the Client layer. This is a simple pass-through to the OpenAI API with streaming enabled. """ try: # Handle TranscriptionOptions object if passed if "options" in kwargs: options = kwargs.pop("options") # Extract all non-None attributes from options object if hasattr(options, "__dict__"): for key, value in options.__dict__.items(): if value is not None and key not in kwargs: kwargs[key] = value # Enable streaming kwargs["stream"] = True # Handle timestamp_granularities requirement if "timestamp_granularities" in kwargs: # OpenAI requires verbose_json format for timestamp_granularities if ( "response_format" in kwargs and kwargs["response_format"] != "verbose_json" ): raise ASRError( f"OpenAI timestamp_granularities requires response_format='verbose_json', " f"but got '{kwargs['response_format']}'. " f"Either remove timestamp_granularities or use response_format='verbose_json'." ) else: kwargs["response_format"] = "verbose_json" try: if isinstance(file, str): with open(file, "rb") as audio_file: response_stream = self.client.audio.transcriptions.create( file=audio_file, model=model, **kwargs ) else: response_stream = self.client.audio.transcriptions.create( file=file, model=model, **kwargs ) # Process streaming response - handle event types for event in response_stream: # Handle TranscriptionTextDeltaEvent (incremental text) if ( hasattr(event, "type") and event.type == "transcript.text.delta" ): if hasattr(event, "delta") and event.delta: yield StreamingTranscriptionChunk( text=event.delta, is_final=False, # Delta events are interim confidence=getattr(event, "confidence", None), ) # Handle TranscriptionTextDoneEvent (final complete text) elif ( hasattr(event, "type") and event.type == "transcript.text.done" ): if hasattr(event, "text") and event.text: yield StreamingTranscriptionChunk( text=event.text, is_final=True, # Done event is final confidence=getattr(event, "confidence", None), ) except Exception 
as stream_error: raise ASRError( f"OpenAI streaming transcription error: {stream_error}" ) from stream_error except Exception as e: raise ASRError(f"OpenAI streaming transcription error: {e}") from e def _parse_openai_response(self, response) -> TranscriptionResult: """Parse OpenAI API response into TranscriptionResult.""" text = response.text if hasattr(response, "text") else "" language = getattr(response, "language", "unknown") # Parse segments if available segments = [] if hasattr(response, "segments") and response.segments: for seg in response.segments: words = [] if hasattr(seg, "words") and seg.words: for word in seg.words: words.append( Word( word=word.word, start=word.start, end=word.end, confidence=getattr(word, "confidence", None), ) ) segments.append( Segment( id=getattr(seg, "id", 0), seek=getattr(seg, "seek", 0), text=seg.text, start=seg.start, end=seg.end, words=words, confidence=getattr(seg, "avg_logprob", None), ) ) return TranscriptionResult( text=text, language=language, confidence=getattr(response, "confidence", None), segments=segments, ) ================================================ FILE: aisuite/providers/sambanova_provider.py ================================================ import os from aisuite.provider import Provider, LLMError from openai import OpenAI from aisuite.providers.message_converter import OpenAICompliantMessageConverter class SambanovaMessageConverter(OpenAICompliantMessageConverter): """ SambaNova-specific message converter. """ pass class SambanovaProvider(Provider): """ SambaNova Provider using OpenAI client for API calls. """ def __init__(self, **config): """ Initialize the SambaNova provider with the given configuration. Pass the entire configuration dictionary to the OpenAI client constructor. """ # Ensure API key is provided either in config or via environment variable self.api_key = config.get("api_key", os.getenv("SAMBANOVA_API_KEY")) if not self.api_key: raise ValueError( "Sambanova API key is missing. 
Please provide it in the config or set the SAMBANOVA_API_KEY environment variable." ) config["api_key"] = self.api_key config["base_url"] = "https://api.sambanova.ai/v1/" # Pass the entire config to the OpenAI client constructor self.client = OpenAI(**config) self.transformer = SambanovaMessageConverter() def chat_completions_create(self, model, messages, **kwargs): """ Makes a request to the SambaNova chat completions endpoint using the OpenAI client. """ try: # Transform messages using converter transformed_messages = self.transformer.convert_request(messages) response = self.client.chat.completions.create( model=model, messages=transformed_messages, **kwargs, # Pass any additional arguments to the Sambanova API ) return self.transformer.convert_response(response.model_dump()) except Exception as e: raise LLMError(f"An error occurred: {e}") ================================================ FILE: aisuite/providers/together_provider.py ================================================ import os import httpx from aisuite.provider import Provider, LLMError from aisuite.providers.message_converter import OpenAICompliantMessageConverter class TogetherMessageConverter(OpenAICompliantMessageConverter): """ Together-specific message converter if needed """ pass class TogetherProvider(Provider): """ Together AI Provider using httpx for direct API calls. """ BASE_URL = "https://api.together.xyz/v1/chat/completions" def __init__(self, **config): """ Initialize the Together provider with the given configuration. The API key is fetched from the config or environment variables. """ self.api_key = config.get("api_key", os.getenv("TOGETHER_API_KEY")) if not self.api_key: raise ValueError( "Together API key is missing. Please provide it in the config or set the TOGETHER_API_KEY environment variable." 
) # Optionally set a custom timeout (default to 30s) self.timeout = config.get("timeout", 30) self.transformer = TogetherMessageConverter() def chat_completions_create(self, model, messages, **kwargs): """ Makes a request to the Together AI chat completions endpoint using httpx. """ # Transform messages using converter transformed_messages = self.transformer.convert_request(messages) headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } data = { "model": model, "messages": transformed_messages, **kwargs, # Pass any additional arguments to the API } try: # Make the request to Together AI endpoint. response = httpx.post( self.BASE_URL, json=data, headers=headers, timeout=self.timeout ) response.raise_for_status() return self.transformer.convert_response(response.json()) except httpx.HTTPStatusError as http_err: raise LLMError(f"Together AI request failed: {http_err}") except Exception as e: raise LLMError(f"An error occurred: {e}") ================================================ FILE: aisuite/providers/watsonx_provider.py ================================================ from aisuite.provider import Provider import os from ibm_watsonx_ai import Credentials from ibm_watsonx_ai.foundation_models import ModelInference from aisuite.framework import ChatCompletionResponse class WatsonxProvider(Provider): def __init__(self, **config): self.service_url = config.get("service_url") or os.getenv("WATSONX_SERVICE_URL") self.api_key = config.get("api_key") or os.getenv("WATSONX_API_KEY") self.project_id = config.get("project_id") or os.getenv("WATSONX_PROJECT_ID") if not self.service_url or not self.api_key or not self.project_id: raise EnvironmentError( "Missing one or more required WatsonX environment variables: " "WATSONX_SERVICE_URL, WATSONX_API_KEY, WATSONX_PROJECT_ID. " "Please refer to the setup guide: /guides/watsonx.md." 
) def chat_completions_create(self, model, messages, **kwargs): model = ModelInference( model_id=model, credentials=Credentials( api_key=self.api_key, url=self.service_url, ), project_id=self.project_id, ) res = model.chat(messages=messages, params=kwargs) return self.normalize_response(res) def normalize_response(self, response): openai_response = ChatCompletionResponse() openai_response.choices[0].message.content = response["choices"][0]["message"][ "content" ] return openai_response ================================================ FILE: aisuite/providers/xai_provider.py ================================================ import os import httpx from aisuite.provider import Provider, LLMError from aisuite.framework import ChatCompletionResponse from aisuite.providers.message_converter import OpenAICompliantMessageConverter class XaiMessageConverter(OpenAICompliantMessageConverter): """ xAI-specific message converter if needed """ pass class XaiProvider(Provider): """ xAI Provider using httpx for direct API calls. """ BASE_URL = "https://api.x.ai/v1/chat/completions" def __init__(self, **config): """ Initialize the xAI provider with the given configuration. The API key is fetched from the config or environment variables. """ self.api_key = config.get("api_key", os.getenv("XAI_API_KEY")) if not self.api_key: raise ValueError( "xAI API key is missing. Please provide it in the config or set the XAI_API_KEY environment variable." ) # Optionally set a custom timeout (default to 30s) self.timeout = config.get("timeout", 30) self.transformer = XaiMessageConverter() def chat_completions_create(self, model, messages, **kwargs): """ Makes a request to the xAI chat completions endpoint using httpx. 
""" # Transform messages using converter transformed_messages = self.transformer.convert_request(messages) headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } data = { "model": model, "messages": transformed_messages, **kwargs, # Pass any additional arguments to the API } try: # Make the request to xAI endpoint. response = httpx.post( self.BASE_URL, json=data, headers=headers, timeout=self.timeout ) response.raise_for_status() return self.transformer.convert_response(response.json()) except httpx.HTTPStatusError as http_err: raise LLMError(f"xAI request failed: {http_err}") except Exception as e: raise LLMError(f"An error occurred: {e}") ================================================ FILE: aisuite/utils/tools.py ================================================ from typing import Callable, Dict, Any, Type, Optional, get_origin, get_args, Union from pydantic import BaseModel, create_model, Field, ValidationError import inspect import json from docstring_parser import parse class Tools: def __init__(self, tools: list[Callable] = None): self._tools = {} if tools: for tool in tools: self._add_tool(tool) # Add a tool function with or without a Pydantic model. def _add_tool(self, func: Callable, param_model: Optional[Type[BaseModel]] = None): """Register a tool function with metadata. 
If no param_model is provided, infer from function signature.""" # Check if this is an MCP tool with original schema if hasattr(func, "__mcp_input_schema__") and func.__mcp_input_schema__: # Use the original MCP schema directly to preserve all JSON Schema details tool_spec = self._convert_mcp_schema_to_tool_spec(func) # Create Pydantic model from MCP schema for validation param_model = self._create_pydantic_model_from_mcp_schema(func) elif param_model: tool_spec = self._convert_to_tool_spec(func, param_model) else: tool_spec, param_model = self.__infer_from_signature(func) self._tools[func.__name__] = { "function": func, "param_model": param_model, "spec": tool_spec, } # Return tools in the specified format (default OpenAI). def tools(self, format="openai") -> list: """Return tools in the specified format (default OpenAI).""" if format == "openai": return self.__convert_to_openai_format() return [tool["spec"] for tool in self._tools.values()] def _unwrap_optional(self, field_type: Type) -> tuple[Type, bool]: """ Unwrap Optional[T] to get the base type T. Returns: tuple: (base_type, is_optional) """ # Check if it's Optional (Union with None) origin = get_origin(field_type) if origin is Union: args = get_args(field_type) # Optional[T] is Union[T, None] if type(None) in args: # Get the non-None type non_none_types = [arg for arg in args if arg is not type(None)] if len(non_none_types) == 1: return non_none_types[0], True return field_type, False # Convert the function and its Pydantic model to a unified tool specification. 
def _convert_to_tool_spec( self, func: Callable, param_model: Type[BaseModel] ) -> Dict[str, Any]: """Convert the function and its Pydantic model to a unified tool specification.""" type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean"} properties = {} for field_name, field in param_model.model_fields.items(): field_type = field.annotation # Unwrap Optional[T] to get base type T field_type, is_optional = self._unwrap_optional(field_type) # Handle enum types if hasattr(field_type, "__members__"): # Check if it's an enum enum_values = [ member.value if hasattr(member, "value") else member.name for member in field_type ] properties[field_name] = { "type": "string", "enum": enum_values, "description": field.description or "", } # Convert enum default value to string if it exists if str(field.default) != "PydanticUndefined": properties[field_name]["default"] = ( field.default.value if hasattr(field.default, "value") else field.default ) else: properties[field_name] = { "type": type_mapping.get(field_type, str(field_type)), "description": field.description or "", } # Add default if it exists and isn't PydanticUndefined if str(field.default) != "PydanticUndefined": properties[field_name]["default"] = field.default return { "name": func.__name__, "description": func.__doc__ or "", "parameters": { "type": "object", "properties": properties, "required": [ name for name, field in param_model.model_fields.items() if field.is_required and str(field.default) == "PydanticUndefined" ], }, } def __extract_param_descriptions(self, func: Callable) -> dict[str, str]: """Extract parameter descriptions from function docstring. 
Args: func: The function to extract parameter descriptions from Returns: Dictionary mapping parameter names to their descriptions """ docstring = inspect.getdoc(func) or "" parsed_docstring = parse(docstring) param_descriptions = {} for param in parsed_docstring.params: param_descriptions[param.arg_name] = param.description or "" return param_descriptions def _convert_mcp_schema_to_tool_spec(self, func: Callable) -> Dict[str, Any]: """ Convert MCP tool with original inputSchema to tool spec. This preserves the original JSON Schema from MCP without round-trip conversion, avoiding information loss for complex types like arrays and nested objects. Args: func: MCP tool wrapper with __mcp_input_schema__ attribute Returns: Tool specification compatible with OpenAI format """ input_schema = func.__mcp_input_schema__ return { "name": func.__name__, "description": func.__doc__ or "", "parameters": input_schema, # Use original schema directly! } def _create_pydantic_model_from_mcp_schema(self, func: Callable) -> Type[BaseModel]: """ Create a Pydantic model from MCP inputSchema for parameter validation. This is needed for the execute() method to validate tool call arguments. 
Args: func: MCP tool wrapper with __mcp_input_schema__ attribute Returns: Pydantic model for parameter validation """ from ..mcp.schema_converter import mcp_schema_to_annotations input_schema = func.__mcp_input_schema__ properties = input_schema.get("properties", {}) required = input_schema.get("required", []) # Get type annotations from MCP schema annotations = mcp_schema_to_annotations(input_schema) fields = {} for param_name, param_type in annotations.items(): param_schema = properties.get(param_name, {}) description = param_schema.get("description", "") if param_name in required: fields[param_name] = (param_type, Field(..., description=description)) else: fields[param_name] = ( param_type, Field(default=None, description=description), ) return create_model(f"{func.__name__.capitalize()}Params", **fields) def __infer_from_signature( self, func: Callable ) -> tuple[Dict[str, Any], Type[BaseModel]]: """Infer parameters(required and optional) and requirements directly from the function signature.""" signature = inspect.signature(func) fields = {} required_fields = [] # Get function's docstring and parse parameter descriptions param_descriptions = self.__extract_param_descriptions(func) docstring = inspect.getdoc(func) or "" # Parse the docstring to get the main function description parsed_docstring = parse(docstring) function_description = parsed_docstring.short_description or "" if parsed_docstring.long_description: function_description += "\n\n" + parsed_docstring.long_description for param_name, param in signature.parameters.items(): # Check if a type annotation is missing if param.annotation == inspect._empty: raise TypeError( f"Parameter '{param_name}' in function '{func.__name__}' must have a type annotation." 
) # Determine field type and optionality param_type = param.annotation description = param_descriptions.get(param_name, "") if param.default == inspect._empty: fields[param_name] = (param_type, Field(..., description=description)) required_fields.append(param_name) else: fields[param_name] = ( param_type, Field(default=param.default, description=description), ) # Dynamically create a Pydantic model based on inferred fields param_model = create_model(f"{func.__name__.capitalize()}Params", **fields) # Convert inferred model to a tool spec format tool_spec = self._convert_to_tool_spec(func, param_model) # Update the tool spec with the parsed function description instead of raw docstring tool_spec["description"] = function_description return tool_spec, param_model def __convert_to_openai_format(self) -> list: """Convert tools to OpenAI's format.""" return [ {"type": "function", "function": tool["spec"]} for tool in self._tools.values() ] def results_to_messages(self, results: list, message: any) -> list: """Converts results to messages.""" # if message is empty return empty list if not message or len(results) == 0: return [] messages = [] # Iterate over results and match with tool calls from the message for result in results: # Find matching tool call from message.tool_calls for tool_call in message.tool_calls: if tool_call.id == result["tool_call_id"]: messages.append( { "role": "tool", "name": result["name"], "content": json.dumps(result["content"]), "tool_call_id": tool_call.id, } ) break return messages def execute(self, tool_calls) -> list: """Executes registered tools based on the tool calls from the model. 
Args: tool_calls: List of tool calls from the model Returns: List of results from executing each tool call """ results = [] # Handle single tool call or list of tool calls if not isinstance(tool_calls, list): tool_calls = [tool_calls] for tool_call in tool_calls: # Handle both dictionary and object-style tool calls if isinstance(tool_call, dict): tool_name = tool_call["function"]["name"] arguments = tool_call["function"]["arguments"] else: tool_name = tool_call.function.name arguments = tool_call.function.arguments # Ensure arguments is a dict if isinstance(arguments, str): arguments = json.loads(arguments) if tool_name not in self._tools: raise ValueError(f"Tool '{tool_name}' not registered.") tool = self._tools[tool_name] tool_func = tool["function"] param_model = tool["param_model"] # Validate and parse the arguments with Pydantic if a model exists try: validated_args = param_model(**arguments) result = tool_func(**validated_args.model_dump()) results.append(result) except ValidationError as e: raise ValueError(f"Error in tool '{tool_name}' parameters: {e}") return results def execute_tool(self, tool_calls) -> tuple[list, list]: """Executes registered tools based on the tool calls from the model. 
Args: tool_calls: List of tool calls from the model Returns: List of tuples containing (result, result_message) for each tool call """ results = [] messages = [] # Handle single tool call or list of tool calls if not isinstance(tool_calls, list): tool_calls = [tool_calls] for tool_call in tool_calls: # Handle both dictionary and object-style tool calls if isinstance(tool_call, dict): tool_name = tool_call["function"]["name"] arguments = tool_call["function"]["arguments"] tool_call_id = tool_call["id"] else: tool_name = tool_call.function.name arguments = tool_call.function.arguments tool_call_id = tool_call.id # Ensure arguments is a dict if isinstance(arguments, str): arguments = json.loads(arguments) if tool_name not in self._tools: raise ValueError(f"Tool '{tool_name}' not registered.") tool = self._tools[tool_name] tool_func = tool["function"] param_model = tool["param_model"] # Validate and parse the arguments with Pydantic if a model exists try: validated_args = param_model(**arguments) result = tool_func(**validated_args.model_dump()) results.append(result) messages.append( { "role": "tool", "name": tool_name, "content": json.dumps(result), "tool_call_id": tool_call_id, } ) except ValidationError as e: raise ValueError(f"Error in tool '{tool_name}' parameters: {e}") return results, messages ================================================ FILE: aisuite/utils/utils.py ================================================ """Utility functions for aisuite.""" import json from unittest.mock import MagicMock from pydantic import BaseModel # pylint: disable=too-few-public-methods class Utils: """ Utility functions for debugging and inspecting objects. """ @staticmethod def spew(obj): """ Recursively inspects a Python object and prints its contents as a nicely formatted JSON string. Handles Pydantic models, nested objects, lists, and circular references. 
""" visited = set() # pylint: disable=too-many-return-statements def default_encoder(o): # Handle MagicMock objects to prevent circular reference errors in tests if isinstance(o, MagicMock): try: # Attempt to get a descriptive name for the mock # pylint: disable=protected-access name = o._extract_mock_name() # pylint: disable=broad-exception-caught except Exception: name = "unknown" return f'' # Handle other circular references obj_id = id(o) if obj_id in visited: return f"" visited.add(obj_id) # Handle Pydantic models if isinstance(o, BaseModel): return o.model_dump() # Handle general objects by converting their __dict__ if hasattr(o, "__dict__"): return o.__dict__ # Handle sets if isinstance(o, set): return list(o) # Fallback for other types try: return str(o) # pylint: disable=broad-exception-caught except Exception: return f"" print(json.dumps(obj, default=default_encoder, indent=2)) ================================================ FILE: aisuite-js/README.md ================================================ # AISuite AISuite is a unified TypeScript library that provides a single, consistent interface for interacting with multiple Large Language Model (LLM) providers. The library uses OpenAI's API format as the standard interface while supporting OpenAI and Anthropic Claude. 
npm package - `npm i aisuite`

## Features

- **Unified API**: Single interface compatible with OpenAI's API structure
- **Multi-Provider Support**: Currently supports OpenAI and Anthropic
- **Provider Selection**: Use `provider:model` format (e.g., `openai:gpt-4o`, `anthropic:claude-3-haiku-20240307`)
- **Tool Calling**: Transparent tool/function calling across all providers
- **Streaming**: Real-time streaming responses with consistent API
- **Type Safety**: Full TypeScript support with comprehensive type definitions
- **Error Handling**: Unified error handling across providers
- **Speech-to-Text**: Automatic Speech Recognition (ASR) support with multiple providers (OpenAI Whisper, Deepgram)

## Installation

```bash
npm install aisuite
```

## Quick Start

```typescript
import { Client } from 'aisuite';

const client = new Client({
  openai: {
    apiKey: process.env.OPENAI_API_KEY,
  },
  anthropic: { apiKey: process.env.ANTHROPIC_API_KEY },
  deepgram: { apiKey: process.env.DEEPGRAM_API_KEY },
});

// Use any provider with identical interface
const response = await client.chat.completions.create({
  model: 'openai:gpt-4o',
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' }
  ],
});

console.log(response.choices[0].message.content);
```

## Usage Examples

### Basic Chat Completion

```typescript
// OpenAI
const openaiResponse = await client.chat.completions.create({
  model: 'openai:gpt-4o',
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'What is TypeScript?' }
  ],
  temperature: 0.7,
  max_tokens: 1000,
});

// Anthropic - exact same interface
const anthropicResponse = await client.chat.completions.create({
  model: 'anthropic:claude-3-haiku-20240307',
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'What is TypeScript?' }
  ],
  temperature: 0.7,
  max_tokens: 1000,
});
```

### Tool/Function Calling

```typescript
const tools = [
  {
    type: 'function' as const,
    function: {
      name: 'get_weather',
      description: 'Get current weather for a location',
      parameters: {
        type: 'object',
        properties: {
          location: { type: 'string', description: 'City name' }
        },
        required: ['location']
      }
    }
  }
];

// Works identically across all providers
const response = await client.chat.completions.create({
  model: 'anthropic:claude-3-haiku-20240307',
  messages: [{ role: 'user', content: 'What\'s the weather in NYC?' }],
  tools,
  tool_choice: 'auto'
});

if (response.choices[0].message.tool_calls) {
  console.log('Tool calls:', response.choices[0].message.tool_calls);
}
```

### Streaming Responses

```typescript
const stream = await client.chat.completions.create({
  model: 'openai:gpt-4o',
  messages: [{ role: 'user', content: 'Tell me a story' }],
  stream: true
});

// TypeScript: cast to AsyncIterable<ChatCompletionChunk>
for await (const chunk of stream as AsyncIterable<ChatCompletionChunk>) {
  process.stdout.write(chunk.choices[0]?.delta?.content || '');
}
```

### Streaming with Abort Controller

```typescript
const controller = new AbortController();

// Abort after 5 seconds
setTimeout(() => controller.abort(), 5000);

const stream = await client.chat.completions.create({
  model: 'anthropic:claude-3-haiku-20240307',
  messages: [{ role: 'user', content: 'Write a long story' }],
  stream: true
}, { signal: controller.signal });

try {
  for await (const chunk of stream as AsyncIterable<ChatCompletionChunk>) {
    process.stdout.write(chunk.choices[0]?.delta?.content || '');
  }
} catch (error) {
  if (error.name === 'AbortError') {
    console.log('Stream aborted');
  }
}
```

### Speech-to-Text Transcription

```typescript
// Initialize client with audio support for OpenAI
const client = new Client({
  openai: {
    apiKey: process.env.OPENAI_API_KEY,
  },
  deepgram: { apiKey: process.env.DEEPGRAM_API_KEY }
});

// Using Deepgram
const deepgramResponse = await client.audio.transcriptions.create({
  model: 'deepgram:nova-2',
  file: audioBuffer, // Buffer containing audio data
  language: 'en-US',
  timestamps: true,
  word_confidence: true,
  speaker_labels: true,
});

// Using OpenAI Whisper
const openaiResponse = await client.audio.transcriptions.create({
  model: 'openai:whisper-1',
  file: audioBuffer,
  language: 'en',
  response_format: 'verbose_json',
  temperature: 0,
  timestamps: true,
});

console.log('Transcribed Text:', openaiResponse.text);
console.log('Words with timestamps:', openaiResponse.words);
```

### Error Handling

```typescript
import { AISuiteError, ProviderNotConfiguredError } from 'aisuite';

try {
  const response = await client.chat.completions.create({
    model: 'invalid:model',
    messages: [{ role: 'user', content: 'Hello' }]
  });
} catch (error) {
  if (error instanceof ProviderNotConfiguredError) {
    console.error('Provider not configured:', error.message);
  } else if (error instanceof AISuiteError) {
    console.error('AISuite error:', error.message, error.provider);
  } else {
    console.error('Unknown error:', error);
  }
}
```

## API Reference

### Client Configuration

```typescript
const client = new Client({
  openai?: {
    apiKey: string;
    baseURL?: string;
    organization?: string;
  },
  anthropic?: {
    apiKey: string;
    baseURL?: string;
  },
  deepgram?: {
    apiKey: string;
    baseURL?: string;
  }
});
```

### Chat Completion Request

All providers use the standard OpenAI chat completion format:

```typescript
interface ChatCompletionRequest {
  model: string; // "provider:model" format
  messages: ChatMessage[];
  tools?: Tool[];
  tool_choice?: ToolChoice;
  temperature?: number;
  max_tokens?: number;
  stop?: string | string[];
  stream?: boolean;
}
```

### Transcription Request

All ASR providers use a standard transcription request format with additional provider-specific parameters:

```typescript
interface TranscriptionRequest {
  model: string; // "provider:model" format
  file: Buffer; // Audio file as Buffer
  language?: string; // Language code (e.g., "en", "en-US")
  timestamps?: boolean; // Include word-level timestamps
  [key: string]:
any; // Additional provider-specific parameters: // For OpenAI: See https://platform.openai.com/docs/api-reference/audio/createTranscription // For Deepgram: See https://developers.deepgram.com/reference/speech-to-text-api/listen } ``` ### Helper Methods ```typescript // List all configured providers (including ASR) client.listProviders(); // ['openai', 'anthropic'] client.listASRProviders(); // ['deepgram', 'openai'] // Check if a provider is configured client.isProviderConfigured('openai'); // true client.isASRProviderConfigured('deepgram'); // true ``` ## Current Limitations - Only OpenAI and Anthropic providers are currently supported for chat (Gemini, Mistral, and Bedrock coming soon) - Tool calling requires handling tool responses manually - Streaming tool calls require manual accumulation of arguments - ASR support is limited to OpenAI Whisper (requires explicit audio configuration) and Deepgram - Some provider-specific ASR features might require using provider-specific parameters ## Development ```bash # Install dependencies npm install # Build the project npm run build # Run tests npm test # Run examples #Run basic usage example only: npm run example:basic # Run tool calling example only: npm run example:tools # Run the full test suite: npm run test:examples ``` ## License MIT ================================================ FILE: aisuite-js/examples/basic-usage.ts ================================================ import 'dotenv/config'; import { Client } from '../src'; async function main() { // Initialize the client with API keys const client = new Client({ openai: { apiKey: process.env.OPENAI_API_KEY! }, anthropic: { apiKey: process.env.ANTHROPIC_API_KEY! 
}, }); console.log('Available providers:', client.listProviders()); // Example 1: OpenAI Chat Completion console.log('\n--- OpenAI Example ---'); try { const openaiResponse = await client.chat.completions.create({ model: 'openai:gpt-4o-mini', messages: [ { role: 'system', content: 'You are a helpful assistant.' }, { role: 'user', content: 'What is TypeScript in one sentence?' } ], temperature: 0.7, max_tokens: 100, }); console.log('OpenAI Response:', openaiResponse.choices[0].message.content); console.log('Usage:', openaiResponse.usage); console.log('Full response:', JSON.stringify(openaiResponse, null, 2)); } catch (error) { console.error('OpenAI Error:', error); } // Example 2: Anthropic Chat Completion console.log('\n--- Anthropic Example ---'); try { const anthropicResponse = await client.chat.completions.create({ model: 'anthropic:claude-3-haiku-20240307', messages: [ { role: 'system', content: 'You are a helpful assistant.' }, { role: 'user', content: 'What is TypeScript in one sentence?' 
} ], temperature: 0.7, max_tokens: 100, }); console.log('Anthropic Response:', anthropicResponse.choices[0].message.content); console.log('Usage:', anthropicResponse.usage); console.log('Full response:', JSON.stringify(anthropicResponse, null, 2)); } catch (error) { console.error('Anthropic Error:', error); } // Example 3: Error handling - invalid provider console.log('\n--- Error Handling Example ---'); try { await client.chat.completions.create({ model: 'invalid:model', messages: [{ role: 'user', content: 'Hello' }] }); } catch (error) { console.error('Expected error:', error); } } // Run the examples main().catch(console.error); ================================================ FILE: aisuite-js/examples/chat-app/.eslintrc.cjs ================================================ module.exports = { root: true, env: { browser: true, es2020: true }, extends: [ 'eslint:recommended', '@typescript-eslint/recommended', 'plugin:react-hooks/recommended', ], ignorePatterns: ['dist', '.eslintrc.cjs'], parser: '@typescript-eslint/parser', plugins: ['react-refresh'], rules: { 'react-refresh/only-export-components': [ 'warn', { allowConstantExport: true }, ], }, } ================================================ FILE: aisuite-js/examples/chat-app/.gitignore ================================================ # Logs logs *.log npm-debug.log* yarn-debug.log* yarn-error.log* pnpm-debug.log* lerna-debug.log* node_modules dist dist-ssr *.local # Editor directories and files .vscode/* !.vscode/extensions.json .idea .DS_Store *.suo *.ntvs* *.njsproj *.sln *.sw? # Environment variables .env .env.local .env.development.local .env.test.local .env.production.local ================================================ FILE: aisuite-js/examples/chat-app/README.md ================================================ # AISuite Chat App A modern React TypeScript chat application built with AISuite, allowing you to chat with multiple AI models and compare their responses in real-time. 
## Features - **Multi-Provider Support**: Chat with OpenAI, Anthropic, Groq, and Mistral models - **Comparison Mode**: Compare responses from two different AI models side-by-side - **Modern UI**: Clean, responsive interface built with React and Tailwind CSS - **Real-time Chat**: Instant messaging with AI models - **API Key Management**: Secure storage and management of API keys - **Error Handling**: Comprehensive error handling and user feedback - **TypeScript**: Full type safety throughout the application ## Prerequisites - Node.js 18+ - npm or yarn - API keys for the AI providers you want to use: - OpenAI API key - Anthropic API key - Groq API key - Mistral API key ## Installation 1. Clone the repository and navigate to the chat app directory: ```bash cd aisuite-js/examples/chat-app ``` 2. Install dependencies: ```bash npm install ``` 3. Start the development server: ```bash npm run dev ``` 4. Open your browser and navigate to `http://localhost:3000` ## Configuration ### API Keys 1. Click the "Configure API Keys" button in the header 2. Enter your API keys for the providers you want to use 3. Click "Save" to store the configuration The app will automatically save your API keys to localStorage for future use. ### Supported Models The app comes pre-configured with the following models: **OpenAI:** - GPT-4o - GPT-4o Mini **Anthropic:** - Claude 3.5 Sonnet - Claude 3 Haiku **Groq:** - Llama 3.1 8B - Mixtral 8x7B **Mistral:** - Mistral 7B - Mistral Large ## Usage ### Basic Chat 1. Configure your API keys 2. Select a model from the dropdown 3. Type your message and press Enter or click Send 4. View the AI response ### Comparison Mode 1. Enable the "Comparison Mode" checkbox 2. Select two different models 3. Send a message to see responses from both models side-by-side 4. 
Compare the different responses and capabilities ### Chat Management - **Reset Chat**: Click the reset button to clear all chat history - **Model Switching**: Change models at any time during the conversation - **Error Handling**: The app displays clear error messages for API issues ## Sample Queries Try these sample queries to test the different models: ``` "What is the weather in Tokyo?" ``` ``` "Write a poem about the weather in Tokyo." ``` ``` "Write a python program to print the fibonacci sequence." ``` ``` "Write test cases for this program." ``` ## Development ### Project Structure ``` src/ ├── components/ # React components │ ├── ApiKeyModal.tsx │ ├── ChatContainer.tsx │ ├── ChatInput.tsx │ ├── ChatMessage.tsx │ ├── ModelSelector.tsx │ └── ProviderSelector.tsx ├── config/ # Configuration files │ └── llm-config.ts ├── services/ # Business logic │ └── aisuite-service.ts ├── types/ # TypeScript type definitions │ └── chat.ts ├── App.tsx # Main application component ├── main.tsx # Application entry point └── index.css # Global styles ``` ### Available Scripts - `npm run dev` - Start development server - `npm run build` - Build for production - `npm run preview` - Preview production build - `npm run lint` - Run ESLint ### Adding New Models To add new models, edit `src/config/llm-config.ts`: ```typescript export const configuredLLMs: LLMConfig[] = [ // ... existing models { name: "Your New Model", provider: "provider-name", model: "model-name" } ]; ``` ### Styling The app uses Tailwind CSS for styling. The design system includes: - Light and dark mode support - Responsive design - Custom scrollbars - Loading animations - Error states ## Technologies Used - **React 18** - UI framework - **TypeScript** - Type safety - **Vite** - Build tool and dev server - **Tailwind CSS** - Styling - **Lucide React** - Icons - **AISuite** - AI provider abstraction ## Browser Support - Chrome 90+ - Firefox 88+ - Safari 14+ - Edge 90+ ## Contributing 1. Fork the repository 2. Create a feature branch 3. 
Make your changes 4. Add tests if applicable 5. Submit a pull request ## License MIT License - see the main repository for details. ## Support For issues and questions: - Check the [AISuite documentation](https://github.com/andrewyng/aisuite) - Open an issue in the repository - Check the console for error messages ## Security Notes - API keys are stored in localStorage (client-side only) - No API keys are sent to any server except the AI providers - Consider using environment variables for production deployments ================================================ FILE: aisuite-js/examples/chat-app/index.html ================================================ AISuite Chat App
================================================ FILE: aisuite-js/examples/chat-app/package.json ================================================ { "name": "aisuite-chat-app", "version": "1.0.0", "description": "A React TypeScript chat application using AISuite", "private": true, "scripts": { "dev": "vite", "build": "tsc && vite build", "preview": "vite preview", "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0" }, "dependencies": { "react": "^18.2.0", "react-dom": "^18.2.0", "lucide-react": "^0.263.1", "clsx": "^2.0.0", "tailwind-merge": "^1.14.0" }, "devDependencies": { "@types/react": "^18.2.15", "@types/react-dom": "^18.2.7", "@typescript-eslint/eslint-plugin": "^6.0.0", "@typescript-eslint/parser": "^6.0.0", "@vitejs/plugin-react": "^4.0.3", "autoprefixer": "^10.4.14", "eslint": "^8.45.0", "eslint-plugin-react-hooks": "^4.6.0", "eslint-plugin-react-refresh": "^0.4.3", "postcss": "^8.4.27", "tailwindcss": "^3.3.3", "typescript": "^5.0.2", "vite": "^4.4.5" } } ================================================ FILE: aisuite-js/examples/chat-app/postcss.config.js ================================================ export default { plugins: { tailwindcss: {}, autoprefixer: {}, }, } ================================================ FILE: aisuite-js/examples/chat-app/src/App.tsx ================================================ import React, { useState, useEffect } from 'react'; import { Settings, AlertCircle } from 'lucide-react'; import { Message, AISuiteConfig } from './types/chat'; import { configuredLLMs, getLLMConfigByName } from './config/llm-config'; import { aiSuiteService } from './services/aisuite-service'; import { ChatContainer } from './components/ChatContainer'; import { ChatInput } from './components/ChatInput'; import { ModelSelector } from './components/ModelSelector'; import { ProviderSelector } from './components/ProviderSelector'; import { ApiKeyModal } from './components/ApiKeyModal'; function App() { const 
[chatHistory1, setChatHistory1] = useState([]); const [chatHistory2, setChatHistory2] = useState([]); const [isProcessing, setIsProcessing] = useState(false); const [useComparisonMode, setUseComparisonMode] = useState(false); const [selectedProvider, setSelectedProvider] = useState(''); const [selectedModel1, setSelectedModel1] = useState(''); const [selectedModel2, setSelectedModel2] = useState(''); const [showApiKeyModal, setShowApiKeyModal] = useState(false); const [apiConfig, setApiConfig] = useState({}); const [error, setError] = useState(null); // Initialize AISuite service when API config changes useEffect(() => { if (Object.keys(apiConfig).length > 0) { try { aiSuiteService.initialize(apiConfig); setError(null); } catch (err) { setError('Failed to initialize AISuite client'); } } }, [apiConfig]); // Load API config from localStorage on mount useEffect(() => { const savedConfig = localStorage.getItem('aisuite-config'); if (savedConfig) { try { const config = JSON.parse(savedConfig); setApiConfig(config); } catch (err) { console.error('Failed to load saved config'); } } }, []); const handleSendMessage = async (message: string) => { if (!message.trim()) return; // Check if provider is selected if (!selectedProvider) { setError('Please select a provider first'); return; } // Check if API key is configured for the selected provider if (!apiConfig[selectedProvider as keyof AISuiteConfig]?.apiKey) { setError(`API key for ${selectedProvider} is not configured. 
Please configure it first.`); setShowApiKeyModal(true); return; } const userMessage: Message = { role: 'user', content: message, timestamp: new Date() }; setIsProcessing(true); setError(null); try { // Add user message to both chat histories setChatHistory1(prev => [...prev, userMessage]); if (useComparisonMode) { setChatHistory2(prev => [...prev, userMessage]); } // Get model configurations const modelConfig1 = getLLMConfigByName(selectedModel1); if (!modelConfig1) { throw new Error(`Model ${selectedModel1} not found`); } // Query first model const response1 = await aiSuiteService.queryLLM(modelConfig1, [...chatHistory1, userMessage]); const assistantMessage1: Message = { role: 'assistant', content: response1, timestamp: new Date() }; setChatHistory1(prev => [...prev, assistantMessage1]); // Query second model if in comparison mode if (useComparisonMode) { const modelConfig2 = getLLMConfigByName(selectedModel2); if (!modelConfig2) { throw new Error(`Model ${selectedModel2} not found`); } const response2 = await aiSuiteService.queryLLM(modelConfig2, [...chatHistory2, userMessage]); const assistantMessage2: Message = { role: 'assistant', content: response2, timestamp: new Date() }; setChatHistory2(prev => [...prev, assistantMessage2]); } } catch (err) { setError(err instanceof Error ? 
err.message : 'An error occurred'); } finally { setIsProcessing(false); } }; const handleResetChat = () => { setChatHistory1([]); setChatHistory2([]); setError(null); }; const handleSaveApiConfig = (config: AISuiteConfig) => { setApiConfig(config); localStorage.setItem('aisuite-config', JSON.stringify(config)); }; // Get all available providers (show all by default) const allProviders = ['openai', 'anthropic', 'groq', 'mistral']; const availableProviders = allProviders; // Get configured providers (those with API keys) const configuredProviders = Object.keys(apiConfig).filter(provider => apiConfig[provider as keyof AISuiteConfig]?.apiKey ); // Get models for the selected provider const availableModels = selectedProvider ? configuredLLMs.filter(model => model.provider === selectedProvider) : []; // Reset model selections when provider changes useEffect(() => { if (selectedProvider) { const providerModels = configuredLLMs.filter(model => model.provider === selectedProvider); if (providerModels.length > 0) { setSelectedModel1(providerModels[0].name); if (useComparisonMode && providerModels.length > 1) { setSelectedModel2(providerModels[1].name); } else { setSelectedModel2(''); } } else { setSelectedModel1(''); setSelectedModel2(''); } } else { setSelectedModel1(''); setSelectedModel2(''); } }, [selectedProvider, useComparisonMode]); const hasConfiguredProviders = Object.keys(apiConfig).length > 0; return (
{/* Header */}

AISuite Chat

{/* Main Content */}
{!hasConfiguredProviders ? (

No API Keys Configured

Please configure your API keys to start chatting with AI models.

) : (
{/* Error Display */} {error && (
{error}
)} {/* Controls */}
{/* Provider Selection */}
{/* Model Selection - Only show if provider is selected */} {selectedProvider && (
{useComparisonMode && availableModels.length > 1 && ( )}
)}
{/* Chat Containers */} {selectedProvider && selectedModel1 && (

{selectedModel1}

{useComparisonMode && selectedModel2 && (

{selectedModel2}

)}
)} {/* No Provider Selected State */} {!selectedProvider && hasConfiguredProviders && (

Select a Provider

Please select an AI provider to start chatting.

)} {/* Chat Input */}
)}
{/* API Key Modal */} setShowApiKeyModal(false)} onSave={handleSaveApiConfig} initialConfig={apiConfig} />
); } export default App; ================================================ FILE: aisuite-js/examples/chat-app/src/components/ApiKeyModal.tsx ================================================ import React, { useState } from 'react'; import { X, Eye, EyeOff } from 'lucide-react'; import { AISuiteConfig } from '../types/chat'; interface ApiKeyModalProps { isOpen: boolean; onClose: () => void; onSave: (config: AISuiteConfig) => void; initialConfig?: AISuiteConfig; } export const ApiKeyModal: React.FC = ({ isOpen, onClose, onSave, initialConfig = {} }) => { const [config, setConfig] = useState(initialConfig); const [showKeys, setShowKeys] = useState>({}); const toggleKeyVisibility = (provider: string) => { setShowKeys(prev => ({ ...prev, [provider]: !prev[provider] })); }; const handleSave = () => { // Filter out empty API keys const filteredConfig: AISuiteConfig = {}; Object.entries(config).forEach(([provider, providerConfig]) => { if (providerConfig?.apiKey?.trim()) { providerConfig.dangerouslyAllowBrowser = true; filteredConfig[provider as keyof AISuiteConfig] = providerConfig; } }); onSave(filteredConfig); onClose(); }; const updateConfig = (provider: string, field: string, value: string) => { setConfig(prev => ({ ...prev, [provider]: { ...prev[provider as keyof AISuiteConfig], [field]: value } })); }; if (!isOpen) return null; return (

Configure API Keys

{/* OpenAI */}
updateConfig('openai', 'apiKey', e.target.value)} className="w-full rounded-lg border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2" />
{/* Anthropic */}
updateConfig('anthropic', 'apiKey', e.target.value)} className="w-full rounded-lg border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2" />
{/* Groq */}
updateConfig('groq', 'apiKey', e.target.value)} className="w-full rounded-lg border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2" />
{/* Mistral */}
updateConfig('mistral', 'apiKey', e.target.value)} className="w-full rounded-lg border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2" />
); }; ================================================ FILE: aisuite-js/examples/chat-app/src/components/ChatContainer.tsx ================================================ import React, { useRef, useEffect } from 'react'; import { Message } from '../types/chat'; import { ChatMessage } from './ChatMessage'; interface ChatContainerProps { messages: Message[]; modelName: string; isLoading?: boolean; } export const ChatContainer: React.FC = ({ messages, modelName, isLoading = false }) => { const messagesEndRef = useRef(null); const scrollToBottom = () => { messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); }; useEffect(() => { scrollToBottom(); }, [messages]); return (
{messages.length === 0 ? (
No messages yet
Start a conversation with {modelName}
) : ( messages.map((message, index) => ( )) )} {isLoading && (
{modelName}
Thinking...
)}
); }; ================================================ FILE: aisuite-js/examples/chat-app/src/components/ChatInput.tsx ================================================ import React, { useState, KeyboardEvent } from 'react'; import { Send, RotateCcw } from 'lucide-react'; interface ChatInputProps { onSendMessage: (message: string) => void; onResetChat: () => void; isLoading: boolean; placeholder?: string; disabled?: boolean; } export const ChatInput: React.FC = ({ onSendMessage, onResetChat, isLoading, placeholder = "Enter your query...", disabled = false }) => { const [message, setMessage] = useState(''); const handleSend = () => { if (message.trim() && !isLoading && !disabled) { onSendMessage(message.trim()); setMessage(''); } }; const handleKeyPress = (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); handleSend(); } }; return (