Repository: guywaldman/orch Branch: main Commit: b9149a83eb07 Files: 86 Total size: 153.5 KB Directory structure: gitextract_rb3yefpi/ ├── .github/ │ └── workflows/ │ ├── build.yml │ └── release.yml ├── .gitignore ├── .rusfmt.toml ├── .vscode/ │ └── settings.json ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE.md ├── README.md ├── RELEASE.md ├── core/ │ ├── Cargo.toml │ ├── README.md │ ├── examples/ │ │ ├── alignment.rs │ │ ├── embeddings.rs │ │ ├── example_utils.rs │ │ ├── structured_data_generation_blog.rs │ │ ├── structured_data_generation_capital.rs │ │ ├── text_generation.rs │ │ ├── text_generation_stream.rs │ │ └── variants_derive.rs │ └── src/ │ ├── alignment/ │ │ ├── mod.rs │ │ ├── strategy.rs │ │ └── strategy_builder.rs │ ├── execution/ │ │ ├── builder.rs │ │ ├── executor.rs │ │ ├── mod.rs │ │ ├── response.rs │ │ ├── structured_executor.rs │ │ └── text_executor.rs │ ├── lib.rs │ ├── lm/ │ │ ├── builder.rs │ │ ├── error.rs │ │ ├── lm_provider/ │ │ │ ├── anthropic/ │ │ │ │ ├── builder.rs │ │ │ │ ├── client/ │ │ │ │ │ ├── anthropic_client.rs │ │ │ │ │ ├── builder.rs │ │ │ │ │ ├── config.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── models.rs │ │ │ │ ├── lm.rs │ │ │ │ ├── mod.rs │ │ │ │ └── models.rs │ │ │ ├── mod.rs │ │ │ ├── models.rs │ │ │ ├── ollama/ │ │ │ │ ├── builder.rs │ │ │ │ ├── config.rs │ │ │ │ ├── lm.rs │ │ │ │ ├── mod.rs │ │ │ │ └── models.rs │ │ │ └── openai/ │ │ │ ├── builder.rs │ │ │ ├── config.rs │ │ │ ├── lm.rs │ │ │ ├── mod.rs │ │ │ └── models.rs │ │ ├── mod.rs │ │ └── models.rs │ ├── net/ │ │ ├── mod.rs │ │ └── sse.rs │ └── response.rs ├── orch.code-workspace ├── response/ │ ├── .gitignore │ ├── Cargo.toml │ └── src/ │ └── lib.rs ├── response_derive/ │ ├── .rustfmt.toml │ ├── Cargo.toml │ ├── README.md │ └── src/ │ ├── attribute_impl.rs │ ├── derive_impl.rs │ └── lib.rs ├── scripts/ │ ├── ci.sh │ ├── examples.sh │ └── utils.sh └── src/ ├── core/ │ ├── mod.rs │ └── net/ │ ├── mod.rs │ └── sse.rs ├── executor.rs ├── lib.rs └── llm/ ├── error.rs ├── llm_provider/ │ ├── mod.rs │ ├── ollama/ │ │ ├── config.rs │ │ ├── llm.rs │ │ ├── mod.rs │ │ └── models.rs │ └── openai.rs ├── mod.rs └── models.rs ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/build.yml ================================================ name: Build on: pull_request: branches: [main] push: branches: [develop] jobs: check: name: Check runs-on: ubuntu-latest steps: - name: Checkout sources uses: actions/checkout@v2 - name: Install stable toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable override: true - uses: Swatinem/rust-cache@v1 - name: Run cargo check uses: actions-rs/cargo@v1 with: command: check test: name: Test Suite strategy: matrix: # TODO: #7 Add tests for Windows & macOS. # os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest] rust: [stable] runs-on: ${{ matrix.os }} steps: - name: Checkout sources uses: actions/checkout@v2 - name: Install stable toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ matrix.rust }} override: true - uses: Swatinem/rust-cache@v1 - name: Unit tests uses: actions-rs/cargo@v1 with: command: test - name: Doc tests uses: actions-rs/cargo@v1 with: command: test args: --doc # TODO: Test with Ollama. 
# - name: Install and run ollama # run: | # ./scripts/ci.sh # - name: Examples (Ollama) # uses: actions-rs/cargo@v1 # with: # command: test # args: -p core --examples -- ollama - name: Examples (OpenAI) run: ./scripts/examples.sh openai env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY_E2E }} - name: Examples (Anthropic) run: ./scripts/examples.sh anthropic env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY_E2E }} lints: name: Lints runs-on: ubuntu-latest steps: - name: Checkout sources uses: actions/checkout@v2 with: submodules: true - name: Install stable toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable override: true components: rustfmt, clippy - uses: Swatinem/rust-cache@v1 - name: Run cargo fmt uses: actions-rs/cargo@v1 with: command: fmt args: --all -- --check - name: Run cargo clippy uses: actions-rs/cargo@v1 with: command: clippy args: -- -D warnings # - name: Run rustdoc lints # uses: actions-rs/cargo@v1 # env: # RUSTDOCFLAGS: "-D missing_docs -D rustdoc::missing_doc_code_examples" # with: # command: doc # args: --workspace --all-features --no-deps --document-private-items ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: push: tags: ["[0-9]+.[0-9]+.[0-9]+"] env: BRANCH_NAME: ${{ github.head_ref || github.ref_name }} jobs: release: name: Release runs-on: ubuntu-latest steps: - name: Exit if not on main branch if: endsWith(github.event.base_ref, 'main') == false run: exit 1 - name: Checkout sources uses: actions/checkout@v2 - name: Install stable toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable override: true - name: Publish `orch_response` run: cargo publish -p orch_response --token ${{ secrets.CRATES_IO_API_TOKEN }} - name: Publish `orch_response_derive` run: cargo publish -p orch_response_derive --token ${{ secrets.CRATES_IO_API_TOKEN }} - name: Publish `orch` run: cargo publish -p orch --token ${{ secrets.CRATES_IO_API_TOKEN }} ================================================ FILE: .gitignore ================================================ # Generated by Cargo # will have compiled files and executables debug/ target/ # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk # MSVC Windows builds of rustc generate these, which store debugging information *.pdb .env* .DS_Store ================================================ FILE: .rusfmt.toml ================================================ max_width = 140 ================================================ FILE: .vscode/settings.json ================================================ { "rust-analyzer.linkedProjects": [ "./Cargo.toml" ] } ================================================ FILE: CHANGELOG.md ================================================ # Changelog # Version 0.0.16 - Added ability to configure the OpenAI API endpoint - Added support for Anthropic models - Fixed issue with `is_local` producing the incorrect result - Improved examples for the `orch` crate # Version 0.0.15 - Removed dimensions from the OpenAI embedding model (no such requirements, compared to Ollama) # Version 0.0.14 - Added `LangaugeModelProvider::is_local` method # Version 0.0.13 No functional changes. 
# Version 0.0.12 - Fixed the proc macro for the variants of the response options ([PR #9](https://github.com/guywaldman/orch/pull/9)) - Added support for "alignment", improved documentation, added examples ([PR #11](https://github.com/guywaldman/orch/pull/11)) # Version 0.0.11 No functional changes. # Version 0.0.10 No functional changes. # Version 0.0.9 No functional changes. # Version 0.0.8 - Added support for boolean fields in the response options # Version 0.0.7 - Fixed issue where the `orch` crate was not used for types in the proc macros - Fixed issue where multiple fields in a response option would fail the proc macro # Version 0.0.6 No functional changes. # Version 0.0.5 - Fixed an issue where the proc macros were not exposed directly from `orch` # Version 0.0.4 No functional changes. # Version 0.0.3 - Added support for streaming responses - Added support for structured data generation - Added a convenience proc macro (`#[derive(OrchResponseOptions)]`) for generating structured data generation options - Added support for OpenAI ## Version 0.0.2 No functional changes. ## Version 0.0.1 Initial release. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to `orch` Thank you for your interest in contributing to `orch`! Please follow this process for contributing code: 1. Open an issue to discuss the changes you want to make 1. Wait for maintainers to approve this change > This step is solely to reduce noise and avoid redundant work for implementing a change that isn't accepted. The assumption is that a maintainer will review the issue in reasonable time. ## Contribution Workflow 1. Fork the repository 1. Implement the change you wish to make 1. Open a pull request that references the relevant issue 1. Make sure that the CI passes 1. Wait for maintainers to review and approve the pull request ## Development Environment ### Prerequisites 1. Rust toolchain and Cargo (MSRV: 1.78.0) ### Workflow ```shell git clone https://github.com/{your-username}/orch cd orch # ...Implement the changes you wish to make cargo fmt cargo test # ...Commit your changes # ...Push your changes # ...Open a pull request ``` ================================================ FILE: Cargo.toml ================================================ [workspace] resolver = "2" members = ["core", "response", "response_derive"] ================================================ FILE: LICENSE.md ================================================ MIT License Copyright (c) [year] [fullname] Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # orch ![Crates.io Version](https://img.shields.io/crates/v/orch?link=https%3A%2F%2Fcrates.io%2Fcrates%2Forch) ![Crates.io Total Downloads](https://img.shields.io/crates/d/orch?link=https%3A%2F%2Fcrates.io%2Fcrates%2Forch) `orch` is a library for building language model powered applications and agents for the Rust programming language. It was primarily built for usage in [magic-cli](https://github.com/guywaldman/magic-cli), but can be used in other contexts as well. > [!NOTE] > > If the project gains traction, this can be compiled as an addon to other languages such as Python or a standalone WebAssembly module. # Installation ```shell cargo add orch cargo add orch_response ``` Alternatively, add `orch` as a dependency to your `Cargo.toml` file: ```toml [dependencies] orch = "*" # Substitute with the latest version orch_response = "*" # Substitute with the latest version ``` # Basic Usage ## Simple Text Generation ```rust no_run use orch::execution::*; use orch::lm::*; #[tokio::main] async fn main() { let lm = OllamaBuilder::new().try_build().unwrap(); let executor = TextExecutorBuilder::new().with_lm(&lm).try_build().unwrap(); let response = executor.execute("What is 2+2?").await.expect("Execution failed"); println!("{}", response.content); } ``` ## Streaming Text Generation ```rust no_run use orch::execution::*; use orch::lm::*; use tokio_stream::StreamExt; #[tokio::main] async fn main() { let lm = OllamaBuilder::new().try_build().unwrap(); let executor = TextExecutorBuilder::new().with_lm(&lm).try_build().unwrap(); let mut response = executor.execute_stream("What is 2+2?").await.expect("Execution failed"); while let Some(chunk) = response.stream.next().await { match chunk { Ok(chunk) => print!("{chunk}"), Err(e) => { println!("Error: {e}"); break; } } } println!(); } ``` ## Structured Data Generation ```rust no_run use orch::execution::*; use orch::lm::*; use orch::response::*; #[derive(Variants, serde::Deserialize)] pub enum ResponseVariants { Answer(AnswerResponseVariant), Fail(FailResponseVariant), } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Answer", scenario = "You know the capital city of the country", description = "Capital city of the country" )] pub struct AnswerResponseVariant { #[schema( description = "Capital city of the received country", example = "London" )] pub capital: String, } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Fail", scenario = "You don't know the capital city of the country", description = "Reason why the capital city is not known" )] pub struct FailResponseVariant { #[schema( description = "Reason why the capital city is not known", example = "Country 'foobar' does not exist" )] pub reason: String, } #[tokio::main] async fn main() { let lm = OllamaBuilder::new().try_build().unwrap(); let executor = StructuredExecutorBuilder::new() .with_lm(&lm) .with_preamble("You are a geography expert who helps users understand the capital city of countries around the world.") .with_options(Box::new(variants!(ResponseVariants))) .try_build() .unwrap(); let response = executor .execute("What is the capital of Fooland?") .await .expect("Execution failed"); 
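    // `response.content` is already deserialized into `ResponseVariants`, using the
    // "response_type" discriminator field that the structured executor instructs the model
    // to emit (see the `variants_derive` example for the raw JSON shape).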
println!("Response:"); match response.content { ResponseVariants::Answer(answer) => { println!("Capital city: {}", answer.capital); } ResponseVariants::Fail(fail) => { println!("Model failed to generate a response: {}", fail.reason); } } } ``` ## Embedding Generation ```rust no_run use orch::execution::*; use orch::lm::*; #[tokio::main] async fn main() { let lm = OllamaBuilder::new().try_build().unwrap(); let executor = TextExecutorBuilder::new() .with_lm(&lm) .try_build() .unwrap(); let embedding = executor .generate_embedding("Phrase to generate an embedding for") .await .expect("Execution failed"); println!("Embedding:"); println!("{:?}", embedding); } ``` ## More Examples See the [examples](https://github.com/guywaldman/orch/tree/main/core/examples) directory for usage examples. ================================================ FILE: RELEASE.md ================================================ # Release Process 1. Update the versions in all package `Cargo.toml` files and version in the README.md file > Alternatively, rename all versions that are not in [CHANGELOG.md](CHANGELOG.md) to the next version. 1. Update [CHANGELOG.md](CHANGELOG.md) 1. Push the changes to the `main` branch with a tag (e.g., `0.0.6`) ================================================ FILE: core/Cargo.toml ================================================ [package] name = "orch" version = "0.0.16" edition = "2021" license = "MIT" description = "Language model orchestration library" homepage = "https://github.com/guywaldman/orch" repository = "https://github.com/guywaldman/orch" keywords = ["llm", "openai", "ollama", "rust"] [dependencies] orch_response = { path = "../response", version = "0.0.16" } orch_response_derive = { path = "../response_derive", version = "0.0.16" } async-gen = "0.2.3" dotenv = "0.15.0" openai-api-rs = "5.0.2" reqwest = { version = "0.12.5", features = ["blocking"] } serde = { version = "1.0.164", features = ["derive"] } serde_json = "1.0.97" thiserror = "1.0.63" tokio = { version = "1.28.2", features = ["rt", "macros"] } tokio-stream = "0.1.15" async-trait = "0.1.81" dyn-clone = "1.0.17" async-recursion = "1.1.1" ================================================ FILE: core/README.md ================================================ # orch See the [main README](../README.md) for more information. ================================================ FILE: core/examples/alignment.rs ================================================ #![allow(dead_code)] use orch::alignment::AlignmentStrategyBuilder; use orch::execution::*; use orch::response::*; mod example_utils; use example_utils::get_lm; #[derive(Variants, serde::Deserialize)] pub enum ResponseVariants { Answer(AnswerResponseVariant), Fail(FailResponseVariant), } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Answer", scenario = "You know the answer", description = "Result of the calculation" )] pub struct AnswerResponseVariant { #[schema(description = "Result of the calculation", example = "42")] pub result: String, } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Fail", scenario = "You don't know the answer", description = "Reason why the answer is not known" )] pub struct FailResponseVariant { #[schema( description = "Reason why the answer is not known", example = "The phrase is not a mathematical related expression" )] pub reason: String, } #[tokio::main] async fn main() { let (lm, _) = get_lm(); // In this example, we use the same LLM for the correction as for the main task. // This could be replaced by a smaller LM. 
let (corrector_lm, _) = get_lm(); // We define an alignment strategy that uses the correction model. let alignment_strategy = AlignmentStrategyBuilder::new() .with_retries(2) .with_lm(&*corrector_lm) .try_build() .unwrap(); let executor = StructuredExecutorBuilder::new() .with_lm(&*lm) .with_preamble(" You are a mathematician who helps users understand the result of mathematical expressions. You will receive a mathematical expression, and you will need to provide the result of that expression. ") .with_options(Box::new(variants!(ResponseVariants))) .with_alignment(alignment_strategy) .try_build() .unwrap(); let response = executor.execute("2 + 2").await.expect("Execution failed"); match response.content { ResponseVariants::Answer(answer) => { println!("Result: {}", answer.result); } ResponseVariants::Fail(fail) => { println!("Model failed to generate a response: {}", fail.reason); } } } ================================================ FILE: core/examples/embeddings.rs ================================================ //! This example demonstrates how to use the `Executor` to generate embeddings from the language model. //! Run like so: `cargo run --example embeddings` mod example_utils; use example_utils::get_lm; use orch::{execution::*, lm::LanguageModelProvider}; #[tokio::main] async fn main() { let (lm, provider) = get_lm(); if provider == LanguageModelProvider::Anthropic { println!("Anthropic does not have built-in embedding models. Skipping example."); return; } let text = "Lorem ipsum"; println!("Text: {text}"); println!("---"); let executor = TextExecutorBuilder::new() .with_lm(&*lm) .try_build() .unwrap(); let embedding = executor .generate_embedding(text) .await .expect("Execution failed"); println!("Embedding:"); println!("{:?}", embedding); } ================================================ FILE: core/examples/example_utils.rs ================================================ use orch::lm::{ AnthropicBuilder, LanguageModel, LanguageModelBuilder, LanguageModelProvider, OllamaBuilder, OpenAiBuilder, }; pub fn get_lm() -> (Box<dyn LanguageModel>, LanguageModelProvider) { let args = std::env::args().collect::<Vec<String>>(); let provider_name = args.get(1).unwrap_or_else(|| { eprintln!("ERROR: Please provide a provider name"); std::process::exit(1); }); let provider = LanguageModelProvider::try_from(provider_name.as_str()) .expect("Invalid provider name. Supported values: 'ollama', 'openai', 'anthropic'"); let open_ai_api_key = { if provider == LanguageModelProvider::OpenAi { std::env::var("OPENAI_API_KEY") .unwrap_or_else(|_| panic!("OPENAI_API_KEY environment variable not set")) } else { String::new() } }; let anthropic_api_key = { if provider == LanguageModelProvider::Anthropic { std::env::var("ANTHROPIC_API_KEY") .unwrap_or_else(|_| panic!("ANTHROPIC_API_KEY environment variable not set")) } else { String::new() } }; let lm: Box<dyn LanguageModel> = match provider { LanguageModelProvider::Ollama => Box::new(OllamaBuilder::new().try_build().unwrap()), LanguageModelProvider::OpenAi => Box::new( OpenAiBuilder::new() .with_api_key(open_ai_api_key) .try_build() .unwrap(), ), LanguageModelProvider::Anthropic => Box::new( AnthropicBuilder::new() .with_api_key(anthropic_api_key) .try_build() .unwrap(), ), }; (lm, provider) } #[allow(dead_code)] fn main() {} ================================================ FILE: core/examples/structured_data_generation_blog.rs ================================================ //! This example demonstrates how to use the `Executor` to generate a structured response from the LLM. //!
Run like so: `cargo run --example structured_data_generation_blog -- blog.md` #![allow(dead_code)] use orch::alignment::AlignmentStrategyBuilder; use orch::execution::*; use orch::response::*; mod example_utils; use example_utils::get_lm; #[derive(Variants, serde::Deserialize)] #[serde(tag = "response_type")] pub enum ResponseVariants { Answer(AnswerResponseVariant), Fail(FailResponseVariant), } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Answer", scenario = "You have reviewed the blog post", description = "Suggestions for improving the blog post" )] pub struct AnswerResponseVariant { #[schema( description = "Suggestions for improving the blog post", example = "[\"You wrote 'excellent' in two consecutive paragraphs in section 'Introduction'\"]" )] pub suggestions: Vec, } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Fail", scenario = "For some reason you failed to generate suggestions", description = "Reason why you failed to generate suggestions" )] pub struct FailResponseVariant { #[schema( description = "Reason why you failed to generate suggestions", example = "Content was invalid" )] pub reason: String, } #[tokio::main] async fn main() { let (lm, _) = get_lm(); // In this example, we use the same LLM for the correction as for the main task. // This could be replaced by a smaller LM. let (corrector_lm, _) = get_lm(); // We define an alignment strategy that uses the correction model. let alignment_strategy = AlignmentStrategyBuilder::new() .with_retries(2) .with_lm(&*corrector_lm) .try_build() .unwrap(); // Mock blog post let prompt = " This is a blog post about the importance of blogging. # Introduction Blogging is a crucial skill for any writer. It allows you to share your thoughts and ideas with others, and it can help you build a following and establish yourself as an expert in your field. "; let executor = StructuredExecutorBuilder::new() .with_lm(&*lm) .with_preamble(" You are an experienced writer and blog post reviewer who helps users improve their blog posts. You will receive a blog post written in Markdown, and you will need to provide suggestions for improving it. Provide *specific* suggestions for improving the blog post, these can as nitpicky as you want. Consider things such as grammar, spelling, clarity, and conciseness. Even things like mentioning the same phrase too much in one paragraph, etc. The tone should be personal, friendly and professional at the same time. Be very specific and refer to specific sentences, paragraph and sections of the blog post. ") .with_options(Box::new(variants!(ResponseVariants))) .with_alignment(alignment_strategy) .try_build() .unwrap(); let response = executor.execute(prompt).await.expect("Execution failed"); match response.content { ResponseVariants::Answer(answer) => { assert!(!answer.suggestions.is_empty()); println!("Suggestions for improving the blog post:"); for suggestion in answer.suggestions { println!("- {}", suggestion); } } ResponseVariants::Fail(fail) => { eprintln!("Model failed to generate a response: {}", fail.reason); std::process::exit(1); } } } ================================================ FILE: core/examples/structured_data_generation_capital.rs ================================================ //! This example demonstrates how to use the `Executor` to generate a structured response from the LLM. //! 
Run like so: `cargo run --example structured_data_generation_capital -- France` #![allow(dead_code)] use orch::execution::*; use orch::response::*; mod example_utils; use example_utils::get_lm; #[derive(Variants, serde::Deserialize)] pub enum ResponseVariants { Answer(AnswerResponseVariant), Fail(FailResponseVariant), } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Answer", scenario = "You know the capital city of the country", description = "Capital city of the country" )] pub struct AnswerResponseVariant { #[schema( description = "Capital city of the received country", example = "London" )] pub capital: String, } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Fail", scenario = "You don't know the capital city of the country", description = "Reason why the capital city is not known" )] pub struct FailResponseVariant { #[schema( description = "Reason why the capital city is not known", example = "Country 'foobar' does not exist" )] pub reason: String, } #[tokio::main] async fn main() { let (lm, _) = get_lm(); let country = "France"; let executor = StructuredExecutorBuilder::new() .with_lm(&*lm) .with_preamble(" You are a geography expert who helps users understand the capital city of countries around the world. You will receive a country name, and you will need to provide the capital city of that country. ") .with_options(Box::new(variants!(ResponseVariants))) .try_build() .unwrap(); let response = executor.execute(country).await.expect("Execution failed"); match response.content { ResponseVariants::Answer(answer) => { println!("Capital city of {}: {}", country, answer.capital); assert_eq!(answer.capital, "Paris"); } ResponseVariants::Fail(fail) => { println!("Model failed to generate a response: {}", fail.reason); } } } ================================================ FILE: core/examples/text_generation.rs ================================================ //! This example demonstrates how to use the `Executor` to generate a response from the LLM. //! Run like so: `cargo run --example text_generation` use orch::execution::*; mod example_utils; use example_utils::get_lm; #[tokio::main] async fn main() { let (lm, _) = get_lm(); let prompt = "What is 2+2?"; println!("Prompt: {prompt}"); println!("---"); let executor = TextExecutorBuilder::new() .with_lm(&*lm) .try_build() .unwrap(); let response = executor.execute(prompt).await.expect("Execution failed"); println!("Response:"); println!("{}", response.content); assert!(response.content.contains('4')); } ================================================ FILE: core/examples/text_generation_stream.rs ================================================ //! This example demonstrates how to use the `Executor` to generate a streaming response from the LLM. //! Run like so: `cargo run --example text_generation_stream` use orch::{execution::*, lm::LanguageModelProvider}; use tokio_stream::StreamExt; mod example_utils; use example_utils::get_lm; #[tokio::main] async fn main() { let (lm, provider) = get_lm(); if provider == LanguageModelProvider::Anthropic { println!("Streaming is not currently supported for Anthropic. 
Skipping example."); return; } let prompt = "What is 2+2?"; println!("Prompt: {prompt}"); println!("---"); let executor = TextExecutorBuilder::new() .with_lm(&*lm) .try_build() .unwrap(); let mut response = executor .execute_stream(prompt) .await .expect("Execution failed"); let mut response_text = String::new(); println!("Response:"); while let Some(chunk) = response.stream.next().await { match chunk { Ok(chunk) => { print!("{chunk}"); response_text.push_str(&chunk); } Err(e) => { println!("Error: {e}"); break; } } } println!(); assert!(!response_text.is_empty()); } ================================================ FILE: core/examples/variants_derive.rs ================================================ //! This example demonstrates how to use the `Variants` derive macro to generate a structured response from the LLM. //! //! Run like so: `cargo run --example variants_derive` use orch::response::*; #[derive(Variants, serde::Deserialize)] pub enum ResponseOptions { Answer(AnswerResponseOption), Fail(FailResponseOption), } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Answer", scenario = "You know the capital city of the country", description = "Capital city of the country" )] pub struct AnswerResponseOption { #[schema( description = "Capital city of the received country", example = "London" )] pub capital: String, #[schema( description = "Country of the received capital city", example = "United Kingdom" )] pub country: String, } #[derive(Variant, serde::Deserialize)] #[variant( variant = "Fail", scenario = "You don't know the capital city of the country", description = "Reason why the capital city is not known" )] pub struct FailResponseOption { #[schema( description = "Reason why the capital city is not known", example = "Country 'foobar' does not exist" )] pub reason: String, } fn main() { let response = r#" { "response_type": "Answer", "capital": "London", "country": "United Kingdom" } "#; let parsed_response = variants!(ResponseOptions).parse(response).unwrap(); match parsed_response { ResponseOptions::Answer(answer_response) => { println!("{}", answer_response.capital); } ResponseOptions::Fail(fail_response) => { println!("{}", fail_response.reason); } } } ================================================ FILE: core/src/alignment/mod.rs ================================================ //! A module containing all logic related to alignment. //! Alignment, in this context, means "aligning" the model's output with the desired output. //! This takes the form of a so-called `[AlignmentStrategy]`, which is a trait that defines how to align the model's output. //! //! This concept has similarities to to traditional "resilience" techniques and libraries, such as .NET's [Polly](https://github.com/App-vNext/Polly), //! which I personally like a lot. 
mod strategy; mod strategy_builder; pub use strategy::*; pub use strategy_builder::*; ================================================ FILE: core/src/alignment/strategy.rs ================================================ use async_recursion::async_recursion; use orch_response_derive::{variants, Variant, Variants}; use thiserror::Error; use crate::{ execution::{ExecutorError, StructuredExecutor, StructuredExecutorBuilder}, lm::{LanguageModel, LanguageModelError, TextCompleteOptions}, }; #[derive(Debug, Error)] pub enum AlignmentError { #[error("Execution failed: {0}")] ExecutionFailed(String), #[error("Language model error: {0}")] LanguageModelError(#[from] LanguageModelError), #[error("Internal error: {0}")] InternalError(String), #[error("Max retries exceeded ({0} retries)")] MaxRetriesExceeded(usize), } pub struct AlignmentStrategy<'a> { pub(crate) lm: &'a dyn LanguageModel, pub(crate) retries: usize, } #[derive(Variants, Clone, serde::Deserialize)] pub enum AlignmentResponse { ResponseCorrection(ResponseCorrectionResponseVariant), SchemaCorrection(SchemaCorrectionResponseVariant), NoCorrection(NoCorrectionResponseVariant), Fail(FailResponseVariant), } #[derive(Variant, Clone, serde::Deserialize)] #[variant( variant = "ResponseCorrection", scenario = "The response format is correct, but the response content itself is incorrect", description = "A correction and a reason why it is needed" )] pub struct ResponseCorrectionResponseVariant { #[schema( description = "Correction of the phrase", example = "{ \"capital\": \"Paris\" }" )] pub correction: String, #[schema( description = "Short reason why a correction is needed", example = "The capital of France is not London as the original model returned, but Paris" )] pub reason: String, } #[derive(Variant, Clone, serde::Deserialize)] #[variant( variant = "SchemaCorrection", scenario = "The schema of the response is incorrect", description = "Explanation of why the schema is incorrect" )] pub struct SchemaCorrectionResponseVariant { #[schema( description = "Correction of the schema, in natural language", example = "\"'capital' should be a string, not a number'\" or \"The 'capital' field has a typo and starts with an uppercase letter\"" )] pub correction: String, #[schema( description = "Short reason why a correction is needed", example = "The 'capital' field is a number, not a string" )] pub reason: String, } #[derive(Variant, Clone, serde::Deserialize)] #[variant( variant = "NoCorrection", scenario = "No correction needed, the original response satisfies the expected output", description = "Short reason why a correction is not needed" )] pub struct NoCorrectionResponseVariant { #[schema( description = "Short reason why a correction is not needed", example = "The user asked for the capital city of France, and the answer is indeed Paris" )] pub reason: String, } #[derive(Variant, Clone, serde::Deserialize)] #[variant( variant = "Fail", scenario = "You don't know how to verify whether the answer is correct or not. 
You should only go for this response in extreme cases", description = "Reason why you failed to determine whether the answer is correct or not" )] pub struct FailResponseVariant { #[schema( description = "Reason why you failed to determine whether the answer is correct or not", example = "The question is extremely vague and the model returned something completely unrelated" )] pub reason: String, } impl<'a> AlignmentStrategy<'a> { const PREAMBLE: &'static str = " Your purpose is to receive a response from a language model and make sure (and correct otherwise) whether the response is expected or not. Being \"expected\" means that the response is correct and matches the expected output. You should *not* return the response in the schema of the original message, but instead of the schema that you are requested to provide (the one with the response types 'ResponseCorrection', 'SchemaCorrection' and 'NoCorrection'). "; /// Aligns the response of the language model. /// Tries at least once, and continues according to the [`AlignmentStrategy`] /// (e.g., number of retries). pub async fn align( &self, base_lm: &'a dyn LanguageModel, original_preamble: &str, original_prompt: &str, original_response: &str, ) -> Result { let mut iterated_response = original_response.to_owned(); let mut retry_count = 0; let mut prev_alignment_response = None; loop { let response = self .request_correction( original_preamble, original_prompt, &iterated_response, &prev_alignment_response, ) .await?; let Some(response) = response else { // The response may be `None` if the correction deemed that the previous response should be used. continue; }; match &response { AlignmentResponse::NoCorrection(_) => { // Found no correction, can return the original response. return Ok(iterated_response.to_owned()); } response => { retry_count += 1; if retry_count >= self.retries { return Err(AlignmentError::MaxRetriesExceeded(retry_count)); } if let AlignmentResponse::Fail(_) = response { // Failed - simply try again. continue; } let correction = match response { AlignmentResponse::ResponseCorrection(response_correction) => { response_correction.correction.clone() } AlignmentResponse::SchemaCorrection(schema_correction) => { schema_correction.correction.clone() } _ => unreachable!(), }; let correction_prompt = format!(" {original_preamble} NOTE: You have previously answered this with the following response and was incorrect. Here is the response and the correction, please make sure not to repeat the same mistake: ORIGINAL RESPONSE: {original_response} CORRECTION: {correction} "); let new_base_model_response = base_lm .text_complete( original_prompt, &correction_prompt, TextCompleteOptions::default(), ) .await .map_err(AlignmentError::LanguageModelError)?; prev_alignment_response = Some(response.clone()); iterated_response = new_base_model_response.text; } } } } #[async_recursion] async fn request_correction( &self, original_preamble: &str, original_prompt: &str, original_response: &str, prev_alignment_response: &Option, ) -> Result, AlignmentError> { let mut preamble = format!( " {base_preamble} The model received the original instructions: {original_preamble} And the original prompt: {original_prompt} And the original response: {original_response} REMEMBER: Return a response in the schema you are requested (the one with the response types 'ResponseCorrection', 'SchemaCorrection' and 'NoCorrection'). 
", base_preamble = Self::PREAMBLE, ); // If `alignment_response` is `None`, then this was the first attempt and no additional preamble is needed. if let Some(prev_alignment_response) = prev_alignment_response { // TODO: Add context of more tries? preamble.push_str(&format!( " IMPORTANT CONTEXT: Before receiving the previous correction, the model has already responded with the following: {} And received the following corrections: ", original_response )); match prev_alignment_response { AlignmentResponse::ResponseCorrection(response_correction) => { preamble.push_str(&format!( "CORRECTION: The response content was incorrect, this is the correction: {}", response_correction.correction, )); } AlignmentResponse::SchemaCorrection(schema_correction) => { preamble.push_str(&format!( "CORRECTION: The response schema was incorrect for the following reason: {} This is the correction: {} ", schema_correction.correction, schema_correction.reason )); } _ => { // No error (this is unexpected) - return the original response. return Err(AlignmentError::InternalError( "Requested correction with no relevant correction response".to_owned(), )); } } } let executor: StructuredExecutor = StructuredExecutorBuilder::new() .with_lm(self.lm) .with_preamble(&preamble) .with_options(Box::new(variants!(AlignmentResponse))) .try_build() .unwrap(); let response = { let correction_response = executor.execute(original_prompt).await; match correction_response { Ok(response) => Some(response.content), Err(ExecutorError::Parsing(_)) => { // The model failed to parse the response, so we return the original response. return Ok(None); } Err(e) => return Err(AlignmentError::ExecutionFailed(e.to_string())), } }; Ok(response) } } ================================================ FILE: core/src/alignment/strategy_builder.rs ================================================ use thiserror::Error; use crate::lm::LanguageModel; use super::strategy::AlignmentStrategy; /// The default number of retries for the alignment strategy, if not overriden. pub const DEFAULT_RETRIES: usize = 2; #[derive(Debug, Error)] pub enum AlignmentStrategyBuilderError { #[error("{0} is not set")] ConfigurationNotSet(String), } #[derive(Default)] pub struct AlignmentStrategyBuilder<'a> { lm: Option<&'a dyn LanguageModel>, retries: Option, } impl<'a> AlignmentStrategyBuilder<'a> { /// Creates a new `AlignmentStrategyBuilder` instance. pub fn new() -> Self { Self { lm: None, retries: Some(DEFAULT_RETRIES), } } /// Sets the language model to use for the alignment strategy. pub fn with_lm(mut self, lm: &'a dyn LanguageModel) -> Self { self.lm = Some(lm); self } /// Sets the number of retries for the alignment strategy. pub fn with_retries(mut self, retries: usize) -> Self { self.retries = Some(retries); self } /// Builds the alignment strategy. /// May fail with a [`AlignmentStrategyBuilderErrro`] if some required configurations are not set. 
pub fn try_build(self) -> Result<AlignmentStrategy<'a>, AlignmentStrategyBuilderError> { let Some(lm) = self.lm else { return Err(AlignmentStrategyBuilderError::ConfigurationNotSet( "Language model".to_string(), )); }; let Some(retries) = self.retries else { return Err(AlignmentStrategyBuilderError::ConfigurationNotSet( "Retries".to_string(), )); }; Ok(AlignmentStrategy { lm, retries }) } } ================================================ FILE: core/src/execution/builder.rs ================================================ use thiserror::Error; #[derive(Debug, Error)] pub enum ExecutorBuilderError { #[error("Internal error: {0}")] InternalError(String), #[error("{0} is not set")] ConfigurationNotSet(String), } ================================================ FILE: core/src/execution/executor.rs ================================================ use std::pin::Pin; use thiserror::Error; use tokio_stream::Stream; use crate::{ alignment::AlignmentError, lm::{LanguageModel, LanguageModelError, OllamaError, TextCompleteOptions}, }; #[derive(Debug, Error)] pub enum ExecutorError { #[error("{0}")] General(LanguageModelError), #[error("{0}")] LanguageModelError(LanguageModelError), #[error("Error when calling Ollama API: {0}")] OllamaApi(String), #[error("Parsing LM response failed: {0}")] Parsing(String), #[error("Alignment error: {0}")] Alignment(AlignmentError), } impl From<LanguageModelError> for ExecutorError { fn from(val: LanguageModelError) -> Self { match val { LanguageModelError::Ollama(OllamaError::Api(e)) => ExecutorError::OllamaApi(e), e => ExecutorError::LanguageModelError(e), } } } pub(crate) trait Executor<'a> { /// Generates a text completion from the LLM (non-streaming). async fn text_complete( &self, prompt: &str, ) -> Result<ExecutorTextCompleteResponse<String>, ExecutorError> { text_complete(self.lm(), prompt, &self.system_prompt()).await } /// System prompt (instructions) for the model. fn system_prompt(&self) -> String; fn lm(&self) -> &'a dyn LanguageModel; } // TODO: Support context for completions (e.g., IDs of past conversations in Ollama). pub struct ExecutorContext; pub struct ExecutorTextCompleteResponse<T> { pub content: T, pub context: ExecutorContext, } pub struct ExecutorTextCompleteStreamResponse { pub stream: Pin<Box<dyn Stream<Item = Result<String, LanguageModelError>> + Send>>, pub context: ExecutorContext, } pub async fn text_complete<'a>( lm: &'a dyn LanguageModel, prompt: &str, system_prompt: &str, ) -> Result<ExecutorTextCompleteResponse<String>, ExecutorError> { let options = TextCompleteOptions { ..Default::default() }; let response = lm .text_complete(prompt, system_prompt, options) .await .map_err(ExecutorError::from)?; Ok(ExecutorTextCompleteResponse { content: response.text, context: ExecutorContext {}, }) } pub(crate) async fn generate_embedding<'a>( lm: &'a dyn LanguageModel, prompt: &str, ) -> Result, ExecutorError> { let response = lm .generate_embedding(prompt) .await .map_err(ExecutorError::from)?; Ok(response) } ================================================ FILE: core/src/execution/mod.rs ================================================ //! A module containing all logic related to LLM execution. //! An [`Executor`] is the terminology for a component which executes an LLM, //! and aligns it appropriately (e.g., error correction). //! //! It is not to be confused with an [`Orchestrator`] which manages the execution of an LLM //! or multiple LLMs towards a task.
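//!
//! # Example
//!
//! A minimal sketch mirroring the "Simple Text Generation" example from the repository README
//! (assumes an async context such as `#[tokio::main]`):
//!
//! ```ignore
//! use orch::execution::*;
//! use orch::lm::*;
//!
//! let lm = OllamaBuilder::new().try_build().unwrap();
//! let executor = TextExecutorBuilder::new().with_lm(&lm).try_build().unwrap();
//! let response = executor.execute("What is 2+2?").await.expect("Execution failed");
//! println!("{}", response.content);
//! ```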
mod builder; mod executor; mod response; mod structured_executor; mod text_executor; pub use builder::*; pub use executor::*; pub use response::*; pub use structured_executor::*; pub use text_executor::*; ================================================ FILE: core/src/execution/response.rs ================================================ #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ResponseFormat { Text, Json, } impl Default for ResponseFormat { fn default() -> Self { Self::Text } } ================================================ FILE: core/src/execution/structured_executor.rs ================================================ use std::cell::OnceCell; use orch_response::{OrchResponseVariants, ResponseSchemaField}; use crate::{alignment::AlignmentStrategy, lm::LanguageModel}; use super::{ generate_embedding, Executor, ExecutorBuilderError, ExecutorContext, ExecutorError, ExecutorTextCompleteResponse, DEFAULT_PREAMBLE, }; pub struct StructuredExecutor<'a, T> { pub(crate) lm: &'a dyn LanguageModel, pub(crate) preamble: Option<&'a str>, pub(crate) variants: Box>, pub(crate) alignment_strategy: Option>, } impl<'a, T> Executor<'a> for StructuredExecutor<'a, T> { fn lm(&self) -> &'a dyn LanguageModel { self.lm } fn system_prompt(&self) -> String { let cell = OnceCell::new(); cell.get_or_init(|| { let response_options = self.variants.variants(); let all_types = response_options .iter() .map(|option| option.type_name.clone()) .collect::>(); let response_options_text = response_options .iter() .map(|option| { let mut schema_text = String::new(); let mut schema_example = "{".to_string(); let type_field = ResponseSchemaField { // NOTE: This is assumed by [`orch_response_derive`] to be the discriminator field. name: "response_type".to_string(), description: format!( "The type of the response (\"{}\" in this case)", option.type_name ) .to_string(), typ: "string".to_string(), example: all_types.first().unwrap().to_string(), }; for (i, field) in option .schema .iter() .chain(std::iter::once(&type_field)) .enumerate() { schema_text.push_str(&format!( " - `{}` of type {} (description: {})\n\n", field.name, field.typ, field.description )); schema_example .push_str(&format!("\"{}\": \"{}\"", field.name, field.example)); if i < option.schema.len() - 1 { schema_example.push(','); } } schema_example.push('}'); format!( "SCENARIO: {}\nDESCRIPTION: {}\nSCHEMA:\n{}\nEXAMPLE RESPONSE: {}\n\n\n", option.scenario, option.description, schema_text, schema_example ) }) .collect::>() .join("\n"); // Add an optional extra preamble supplied by the user. let preamble = self.preamble.map(|pa| format!("Additional information: {}", pa)).unwrap_or("".to_owned()); let system_prompt = format!( " You will receive a prompt from a user, and will need to response with a JSON object that represents the response. Response *only* with the JSON object, and nothing else. No additional preamble or explanations. Only work with the responses you can reply with. {preamble} You have {choices_len} choices to respond, in a JSON format: {response_options_text} ", preamble = preamble, choices_len = response_options.len(), response_options_text = response_options_text ) .trim() .to_string(); system_prompt }) .clone() } } /// Trait for LLM execution. /// This should be implemented for each LLM text generation use-case, where the system prompt /// changes according to the trait implementations. impl<'a, T> StructuredExecutor<'a, T> { /// Generates a structured response from the LLM (non-streaming). 
/// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. pub async fn execute( &'a self, prompt: &'a str, ) -> Result, ExecutorError> { let mut model_response = self.text_complete(prompt).await?.content; if let Some(alignment_strategy) = &self.alignment_strategy { model_response = alignment_strategy .align( self.lm, self.preamble.unwrap_or(DEFAULT_PREAMBLE), prompt, &model_response, ) .await .map_err(ExecutorError::Alignment)?; } let result = self .variants .parse(&model_response) .map_err(|e| ExecutorError::Parsing(format!("{e}\nResponse: {:?}", model_response)))?; // TODO: Add error correction and handling. Ok(ExecutorTextCompleteResponse { content: result, context: ExecutorContext {}, }) } /// Generates an embedding from the LLM. /// /// # Arguments /// * `prompt` - The item to generate an embedding for. /// /// # Returns /// /// A [Result] containing the embedding or an error if there was a problem. pub async fn generate_embedding(&'a self, prompt: &'a str) -> Result, ExecutorError> { generate_embedding(self.lm, prompt).await } } #[derive(Default)] pub struct StructuredExecutorBuilder<'a, T> { lm: Option<&'a dyn LanguageModel>, preamble: Option<&'a str>, variants: Option>>, alignment_strategy: Option>, } impl<'a, T> StructuredExecutorBuilder<'a, T> { pub fn new() -> Self { Self { lm: None, preamble: None, variants: None, alignment_strategy: None, } } pub fn with_lm(mut self, lm: &'a dyn LanguageModel) -> Self { self.lm = Some(lm); self } pub fn with_options(mut self, options: Box>) -> Self { self.variants = Some(options); self } pub fn with_preamble(mut self, preamble: &'a str) -> Self { self.preamble = Some(preamble); self } pub fn with_alignment(mut self, strategy: AlignmentStrategy<'a>) -> Self { self.alignment_strategy = Some(strategy); self } pub fn try_build(self) -> Result, ExecutorBuilderError> { let Some(lm) = self.lm else { return Err(ExecutorBuilderError::ConfigurationNotSet( "Language model".to_string(), )); }; let Some(response_options) = self.variants else { return Err(ExecutorBuilderError::InternalError( "Response variants are not set".to_string(), )); }; Ok(StructuredExecutor { lm, preamble: self.preamble, variants: response_options, alignment_strategy: self.alignment_strategy, }) } } ================================================ FILE: core/src/execution/text_executor.rs ================================================ use crate::lm::{LanguageModel, TextCompleteStreamOptions}; use super::{ generate_embedding, Executor, ExecutorBuilderError, ExecutorContext, ExecutorError, ExecutorTextCompleteResponse, ExecutorTextCompleteStreamResponse, }; pub const DEFAULT_PREAMBLE: &str = "You are a helpful assistant"; pub struct TextExecutor<'a> { pub(crate) lm: &'a dyn LanguageModel, pub(crate) preamble: Option<&'a str>, } impl<'a> Executor<'a> for TextExecutor<'a> { fn lm(&self) -> &'a dyn LanguageModel { self.lm } fn system_prompt(&self) -> String { self.preamble.unwrap_or(DEFAULT_PREAMBLE).to_owned() } } /// Trait for LLM execution. /// This should be implemented for each LLM text generation use-case, where the system prompt /// changes according to the trait implementations. impl<'a> TextExecutor<'a> { /// Generates a streaming response from the LLM. /// /// # Arguments /// * `prompt` - The prompt to generate a response for. 
/// * `system_prompt` - The system prompt to use for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. pub async fn execute_stream( &'a self, prompt: &'a str, ) -> Result { let options = TextCompleteStreamOptions { ..Default::default() }; let system_prompt = self.system_prompt(); let response = self .lm .text_complete_stream(prompt, &system_prompt, options) .await .map_err(ExecutorError::General)?; Ok(ExecutorTextCompleteStreamResponse { stream: response.stream, context: ExecutorContext {}, }) } /// Generates a response from the LLM (non-streaming). /// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. pub async fn execute( &'a self, prompt: &'a str, ) -> Result, ExecutorError> { self.text_complete(prompt).await } /// Generates an embedding from the LLM. /// /// # Arguments /// * `prompt` - The item to generate an embedding for. /// /// # Returns /// /// A [Result] containing the embedding or an error if there was a problem. pub async fn generate_embedding(&'a self, prompt: &'a str) -> Result, ExecutorError> { generate_embedding(self.lm, prompt).await } } #[derive(Default)] pub struct TextExecutorBuilder<'a> { lm: Option<&'a dyn LanguageModel>, preamble: Option<&'a str>, } impl<'a> TextExecutorBuilder<'a> { pub fn new() -> Self { Self { lm: None, preamble: None, } } pub fn with_lm(mut self, lm: &'a dyn LanguageModel) -> Self { self.lm = Some(lm); self } pub fn with_preamble(mut self, preamble: &'a str) -> Self { self.preamble = Some(preamble); self } pub fn try_build(self) -> Result, ExecutorBuilderError> { let Some(lm) = self.lm else { return Err(ExecutorBuilderError::ConfigurationNotSet( "Language model".to_string(), )); }; Ok(TextExecutor { lm, preamble: self.preamble, }) } } ================================================ FILE: core/src/lib.rs ================================================ #![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))] pub mod alignment; pub mod execution; pub mod lm; mod net; pub mod response; ================================================ FILE: core/src/lm/builder.rs ================================================ use thiserror::Error; use super::LanguageModel; #[derive(Debug, Error)] pub enum LanguageModelBuilderError { #[error("{0} is not set")] ConfigurationNotSet(String), } pub trait LanguageModelBuilder { fn new() -> Self; fn try_build(self) -> Result; } ================================================ FILE: core/src/lm/error.rs ================================================ use thiserror::Error; use super::{AnthropicError, LanguageModelProvider, OllamaError, OpenAiError}; #[derive(Debug, Error)] pub enum LanguageModelProviderError { #[error("Invalid LLM provider: {0}")] InvalidValue(String), } impl std::fmt::Display for LanguageModelProvider { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { LanguageModelProvider::Ollama => write!(f, "ollama"), LanguageModelProvider::OpenAi => write!(f, "openai"), LanguageModelProvider::Anthropic => write!(f, "anthropic"), } } } impl Default for LanguageModelProvider { fn default() -> Self { Self::Ollama } } #[derive(Debug, Error)] pub enum LanguageModelError { #[error("Text generation error: {0}")] TextGeneration(String), #[error("Feature unsupported: {0}")] UnsupportedFeature(String), 
#[error("Embedding generation error: {0}")] EmbeddingGeneration(String), #[error("Configuration error: {0}")] Configuration(String), #[error("Ollama error: {0}")] Ollama(#[from] OllamaError), #[error("OpenAI error: {0}")] OpenAi(#[from] OpenAiError), #[error("Anthropic error: {0}")] Anthropic(#[from] AnthropicError), } ================================================ FILE: core/src/lm/lm_provider/anthropic/builder.rs ================================================ use thiserror::Error; use crate::lm::{LanguageModelBuilder, LanguageModelBuilderError}; use super::client::config::{DEFAULT_API_ENDPOINT, DEFAULT_MODEL}; use super::Anthropic; #[derive(Debug, Error)] pub enum AnthropicBuilderError { #[error("Configuration error: {0} is not set")] ConfigurationNotSet(String), } /// Builds an [`Anthropic`] instance. pub struct AnthropicBuilder { /// API key for the Anthropic API. Required. api_key: Option, /// Base URL for the Anthropic API. Defaults to [`DEFAULT_BASE_URL`]. api_endpoint: Option, /// Model to use for text completion. Defaults to [`DEFAULT_MODEL`]. model: Option, } impl AnthropicBuilder { /// Overrides the default base URL for the Anthropic API. /// Defaults to [`DEFAULT_API_ENDPOINT`]. pub fn with_api_endpoint(mut self, base_url: String) -> Self { self.api_endpoint = Some(base_url); self } /// Sets the required API key for the Anthropic API. pub fn with_api_key(mut self, api_key: String) -> Self { self.api_key = Some(api_key); self } /// Overrides the default model to use for text completion. /// Defaults to [`DEFAULT_MODEL`]. pub fn with_model(mut self, model: String) -> Self { self.model = Some(model); self } } impl LanguageModelBuilder for AnthropicBuilder { fn new() -> Self { Self { api_key: None, api_endpoint: Some(DEFAULT_API_ENDPOINT.to_string()), model: Some(DEFAULT_MODEL.to_string()), } } /// Tries to build an [`Anthropic`] instance. May fail if the required configurations are not set. fn try_build(self) -> Result { let Some(api_endpoint) = self.api_endpoint else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "API endpoint".to_string(), )); }; let Some(api_key) = self.api_key else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "API key".to_string(), )); }; let Some(model) = self.model else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "Model".to_string(), )); }; Ok(Anthropic { api_key, api_endpoint, model: model.to_owned(), }) } } ================================================ FILE: core/src/lm/lm_provider/anthropic/client/anthropic_client.rs ================================================ #![allow(dead_code)] use thiserror::Error; use crate::lm::lm_provider::anthropic::client::models::AnthropicMessagesApiRequest; use super::{ config::DEFAULT_MAX_TOKENS, models::{ AnthropicMessage, AnthropicMessagesApiMessage, AnthropicMessagesApiResponse, AnthropicMessagesApiResponseSuccess, }, }; #[derive(Debug, Error)] pub(crate) enum AnthropicClientError { #[error("{0}")] InternalError(String), #[error("Configuration '{0}' is not set")] ConfigurationNotSet(String), #[error("Failed to serialize/deserialize: {0}")] Marhsalling(String), #[error("Failed to send or receive request to/from Anthropic API: {0}")] Api(String), } /// A client for interacting with the Anthropic API. pub struct AnthropicClient { pub(crate) api_endpoint: String, pub(crate) api_key: String, } impl AnthropicClient { /// Generates a response from the Anthropic API. /// /// # Arguments /// * `prompt` - The prompt to generate a response for. 
Use "User:...\n\n" for user messages and "Assistant:...\n\n" for assistant messages. /// * `system_prompt` - The system prompt to use for the generation. /// * `options` - The options for the generation (use [`AnthropicClientTextCompleteOptionsBuilder`] to build a new instance). /// /// # Returns /// A [Result] containing the response from the Anthropic API or an error if there was a problem. pub async fn text_complete( &self, messages: &[AnthropicMessage], system_prompt: &str, options: AnthropicClientTextCompleteOptions, ) -> Result { let messages_api_endpoint = format!("{}/v1/messages", self.api_endpoint); let system_prompt = if system_prompt.is_empty() { None } else { Some(system_prompt.to_string()) }; let messages = messages .iter() .map(Self::construct_message) .collect::>(); let req_body = AnthropicMessagesApiRequest { messages, system_prompt, model: options.model, max_tokens_to_sample: DEFAULT_MAX_TOKENS, stop_sequences: None, temperature: None, top_k: None, }; let http_client = reqwest::Client::new(); let req = http_client .post(messages_api_endpoint) // See Anthropic authentication documentation: https://docs.anthropic.com/en/api/getting-started#authentication .header("x-api-key", &self.api_key) .header("anthropic-version", "2023-06-01") .header(reqwest::header::CONTENT_TYPE, "application/json") .body( serde_json::to_string(&req_body) .map_err(|e| AnthropicClientError::Marhsalling(e.to_string()))?, ) .build() .map_err(|e| AnthropicClientError::InternalError(e.to_string()))?; let response = http_client .execute(req) .await .map_err(|e| AnthropicClientError::Api(e.to_string()))?; let response_body_json = response .text() .await .map_err(|e| AnthropicClientError::Api(e.to_string()))? .to_string(); let deserialized_response: AnthropicMessagesApiResponse = serde_json::from_str(&response_body_json).map_err(|e| { AnthropicClientError::Marhsalling(format!( "Failed to parse response: {e} (response: {response_body_json})" )) })?; Ok(match deserialized_response { AnthropicMessagesApiResponse::Success(success_response) => success_response, AnthropicMessagesApiResponse::Error(error_response) => { let error_message = error_response.error.message; return Err(AnthropicClientError::Api(error_message)); } }) } fn construct_message(msg: &AnthropicMessage) -> AnthropicMessagesApiMessage { match msg { AnthropicMessage::User(content) => AnthropicMessagesApiMessage { role: "user".to_string(), content: content.to_string(), }, AnthropicMessage::Assistant(content) => AnthropicMessagesApiMessage { role: "assistant".to_string(), content: content.to_string(), }, } } } /// Options for text completion. #[derive(Debug, Default)] pub struct AnthropicClientTextCompleteOptions { /// See [`AnthropicCompleteApiRequest::model`]. pub model: String, /// See [`AnthropicCompleteApiRequest::max_tokens`]. pub max_tokens_to_sample: usize, /// See [`AnthropicCompleteApiRequest::stop_sequences`]. pub stop_sequences: Option>, /// See [`AnthropicCompleteApiRequest::temperature`]. pub temperature: Option, /// See [`AnthropicCompleteApiRequest::top_k`]. pub top_k: Option, } /// Builds a new [`AnthropicClientTextCompleteOptions`] instance. #[derive(Debug, Default)] pub struct AnthropicClientTextCompleteOptionsBuilder { model: Option, max_tokens: usize, stop_sequences: Option>, temperature: Option, top_k: Option, } impl AnthropicClientTextCompleteOptionsBuilder { pub fn new() -> Self { Self { max_tokens: DEFAULT_MAX_TOKENS, ..Default::default() } } /// Sets the model (required). 
pub fn with_model(mut self, model: String) -> Self { self.model = Some(model); self } /// Sets the maximum number of tokens to generate before stopping. /// Defaults to [`DEFAULT_MAX_TOKENS`]. pub fn with_max_tokens(mut self, max_tokens: usize) -> Self { self.max_tokens = max_tokens; self } /// Sets the stop sequences. pub fn with_stop_sequences(mut self, stop_sequences: Vec) -> Self { self.stop_sequences = Some(stop_sequences); self } /// Sets the temperature. pub fn with_temperature(mut self, temperature: f32) -> Self { self.temperature = Some(temperature); self } /// Sets the top k. pub fn with_top_k(mut self, top_k: usize) -> Self { self.top_k = Some(top_k); self } /// Tries to build a [`AnthropicClientTextCompleteOptions`] instance. May fail if the required configurations are not set. pub fn try_build(self) -> Result { let Some(model) = self.model else { return Err(AnthropicClientError::ConfigurationNotSet( "Model".to_string(), )); }; Ok(AnthropicClientTextCompleteOptions { model, max_tokens_to_sample: self.max_tokens, stop_sequences: self.stop_sequences, temperature: self.temperature, top_k: self.top_k, }) } } ================================================ FILE: core/src/lm/lm_provider/anthropic/client/builder.rs ================================================ use thiserror::Error; use super::{anthropic_client::AnthropicClient, config}; #[derive(Debug, Error)] pub enum AnthropicBuilderError { #[error("Configuration error: {0} is not set")] ConfigurationNotSet(String), } /// Builds an [`AnthropicClient`] instance. pub struct AnthropicClientBuilder { api_endpoint: String, api_key: Option, } impl AnthropicClientBuilder { pub fn new() -> Self { Self { api_endpoint: config::DEFAULT_API_ENDPOINT.to_string(), api_key: None, } } /// Sets an override for the Anthropic API endpoint. Defaults to [`config::DEFAULT_API_ENDPOINT`]. pub fn with_api_endpoint(mut self, api_endpoint: String) -> Self { self.api_endpoint = api_endpoint; self } /// Sets the required Anthropic API key. pub fn with_api_key(mut self, api_key: String) -> Self { self.api_key = Some(api_key); self } pub fn try_build(self) -> Result { let Some(api_key) = self.api_key else { return Err(AnthropicBuilderError::ConfigurationNotSet( "API key".to_string(), )); }; Ok(AnthropicClient { api_endpoint: self.api_endpoint, api_key, }) } } ================================================ FILE: core/src/lm/lm_provider/anthropic/client/config.rs ================================================ use crate::lm::anthropic_model; /// Default API endpoint for the Anthropic API. pub const DEFAULT_API_ENDPOINT: &str = "https://api.anthropic.com"; /// Default model to use for text completion. pub const DEFAULT_MODEL: &str = anthropic_model::CLAUDE_3_5_SONNET; /// Default maximum number of tokens to generate before stopping. pub const DEFAULT_MAX_TOKENS: usize = 2048; ================================================ FILE: core/src/lm/lm_provider/anthropic/client/mod.rs ================================================ pub mod anthropic_client; pub mod builder; pub mod config; pub mod models; ================================================ FILE: core/src/lm/lm_provider/anthropic/client/models.rs ================================================ use serde::{Deserialize, Serialize}; /// Request for generating a response from the Anthropic API. /// Referenced from the Anthropic API documentation [here](https://docs.anthropic.com/en/api/complete). 
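// Example (sketch, not part of the original source): a crate-internal call to the Anthropic
// messages API through `AnthropicClient` and its options builder, defined above. The model
// string, question and system prompt are illustrative; errors are boxed for brevity.
async fn example_anthropic_messages_call() -> Result<String, Box<dyn std::error::Error>> {
    let client = AnthropicClientBuilder::new()
        .with_api_key(std::env::var("ANTHROPIC_API_KEY")?)
        .try_build()?;
    let options = AnthropicClientTextCompleteOptionsBuilder::new()
        .with_model("claude-3-5-sonnet-20240620".to_string())
        .with_max_tokens(512)
        .try_build()?;
    let messages = vec![AnthropicMessage::User("What is the capital of France?".to_string())];
    let response = client
        .text_complete(&messages, "You are a concise assistant.", options)
        .await?;
    // The first content block of a successful response carries the generated text.
    Ok(response.content.first().map(|c| c.text.clone()).unwrap_or_default())
}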
#[derive(Debug, Serialize, Deserialize)] pub struct AnthropicMessagesApiRequest { /// The model that will complete your prompt. /// See [`anthropic_model`] for a list of built-in model IDs for convenience. /// /// See [models](https://docs.anthropic.com/en/docs/about-claude/models) for a complete list of models supported by Anthropic. pub model: String, /// Messages to generate a completion for. /// /// See [Anthropic API documentation](https://docs.anthropic.com/en/api/messages) for more information. pub messages: Vec, /// Optional system prompt. #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "system")] pub system_prompt: Option, /// The maximum number of tokens to generate before stopping. /// /// Note that the Anthropic models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. #[serde(rename = "max_tokens")] pub max_tokens_to_sample: usize, /// Sequences that will cause the model to stop generating. /// The Anthropic models stop on "\n\nHuman:", and may include additional built-in stop sequences in the future. /// By providing the stop_sequences parameter, you may include additional strings that will cause the model to stop generating. #[serde(skip_serializing_if = "Option::is_none")] pub stop_sequences: Option>, /// Amount of randomness injected into the response. /// /// Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. /// // Note that even with temperature of 0.0, the results will not be fully deterministic. #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, /// Only sample from the top K options for each subsequent token. /// /// Used to remove "long tail" low probability responses. Learn more technical details [here](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277). /// /// Recommended for advanced use cases only. You usually only need to use temperature. #[serde(skip_serializing_if = "Option::is_none")] pub top_k: Option, } #[derive(Debug, PartialEq, Eq)] pub enum AnthropicMessage { /// A user message. User(String), /// An assistant message. Assistant(String), } #[derive(Debug)] pub enum AnthropicMessageRole { User, Assistant, } impl std::fmt::Display for AnthropicMessageRole { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { AnthropicMessageRole::User => write!(f, "User"), AnthropicMessageRole::Assistant => write!(f, "Assistant"), } } } #[derive(Debug, Serialize, Deserialize)] pub struct AnthropicMessagesApiMessage { /// The role of the message. /// For a user, this will be "user". For an assistant, this will be "assistant". pub role: String, /// The content of the message. pub content: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum AnthropicMessagesApiResponse { Success(AnthropicMessagesApiResponseSuccess), Error(AnthropicApiError), } /// Response from the Anthropic API for generating a response. /// Referenced from the Anthropic API documentation [here](https://docs.anthropic.com/en/api/complete). #[derive(Debug, Serialize, Deserialize)] pub struct AnthropicMessagesApiResponseSuccess { /// Object type. For text completion, this will be "completion". #[serde(rename = "type")] pub typ: String, /// The role of the message. For responses, this would be "assistant". pub role: String, /// The model that generated the response. 
pub model: String, /// The resulting completion up to and excluding the stop sequences. pub content: Vec, /// The reason that the model stopped generating tokens. /// /// This may be one the following values: /// - "stop_sequence": Reached a stop sequence — either provided by you via the stop_sequences parameter, or a stop sequence built into the model /// - "max_tokens": Exceeded `max_tokens_to_sample` or the model's maximum pub stop_reason: Option, } /// Response from the Anthropic API for the messages API endpoint. #[derive(Debug, Serialize, Deserialize)] pub struct AnthropicMessagesApiResponseSuccessContent { /// Type of the response, for text completion, this will be "text". #[serde(rename = "type")] pub typ: String, /// The content of the response. pub text: String, } /// Response from the Anthropic API which indicates an error. /// Referenced from the Anthropic API documentation [here](https://docs.anthropic.com/en/api/errors#error-shapes). #[derive(Debug, Serialize, Deserialize)] pub struct AnthropicApiError { /// Type of the response, for error responses this will be "error". #[serde(rename = "type")] pub typ: String, /// Error message. pub error: AnthropicApiErrorBody, } #[derive(Debug, Serialize, Deserialize)] pub struct AnthropicApiErrorBody { /// Type of the error (e.g., "not_found_error"). #[serde(rename = "type")] pub typ: String, /// Error message. pub message: String, } ================================================ FILE: core/src/lm/lm_provider/anthropic/lm.rs ================================================ use async_trait::async_trait; use thiserror::Error; use crate::lm::{ LanguageModel, LanguageModelError, LanguageModelProvider, TextCompleteOptions, TextCompleteResponse, TextCompleteStreamOptions, TextCompleteStreamResponse, }; use super::client::{ anthropic_client::AnthropicClientTextCompleteOptionsBuilder, builder::AnthropicClientBuilder, models::{AnthropicMessage, AnthropicMessageRole}, }; #[derive(Debug, Clone)] pub struct Anthropic { pub api_endpoint: String, pub api_key: String, pub model: String, } #[derive(Error, Debug)] pub enum AnthropicError { #[error("Unexpected response from API. Error: {0}")] Api(String), #[error("Configuration error: {0}")] Configuration(String), #[error("Serialization error: {0}")] Serialization(String), #[error( "OpenAi API is not available. Please check if OpenAi is running in the specified port. Error: {0}" )] ApiUnavailable(String), #[error("Invalid input: {0}")] InvalidInput(String), } #[async_trait] impl LanguageModel for Anthropic { // TODO: Support context. async fn text_complete( &self, prompt: &str, system_prompt: &str, _options: TextCompleteOptions, ) -> Result { let client = AnthropicClientBuilder::new() .with_api_endpoint(self.api_endpoint.clone()) .with_api_key(self.api_key.clone()) .try_build() .map_err(|e| { LanguageModelError::Anthropic(AnthropicError::Configuration(e.to_string())) })?; let options = AnthropicClientTextCompleteOptionsBuilder::new() .with_model(self.model.clone()) .try_build() .map_err(|e| { LanguageModelError::Anthropic(AnthropicError::Configuration(e.to_string())) })?; // In the case of Anthropic, we need to supply the full history of the conversation. // We therefore parse the prompt string and construct the messages. 
let messages = Self::messages_from_prompt(prompt)?; let response = client .text_complete(messages.as_slice(), system_prompt, options) .await .map_err(|e| LanguageModelError::Anthropic(AnthropicError::Api(e.to_string())))?; let response_content = response .content .first() .ok_or(AnthropicError::Api("Response content is empty".to_string()))?; Ok(TextCompleteResponse { text: response_content.text.clone(), context: None, }) } async fn text_complete_stream( &self, _prompt: &str, _system_prompt: &str, _options: TextCompleteStreamOptions, ) -> Result { return Err(LanguageModelError::UnsupportedFeature( "Streaming is not supported for Anthropic".to_string(), )); } async fn generate_embedding(&self, _prompt: &str) -> Result, LanguageModelError> { return Err(LanguageModelError::UnsupportedFeature( "Embedding generation is not available on Anthropic. For more details see https://docs.anthropic.com/en/docs/build-with-claude/embeddings".to_string(), )); } fn provider(&self) -> LanguageModelProvider { LanguageModelProvider::Anthropic } fn text_completion_model_name(&self) -> String { self.model.to_string() } fn embedding_model_name(&self) -> String { "(UNSUPPORTED)".to_string() } } impl Anthropic { fn messages_from_prompt(prompt: &str) -> Result, LanguageModelError> { if !prompt.starts_with("User:") && !prompt.starts_with("Assistant:") { // Assume the prompt is just the user message. return Ok(vec![AnthropicMessage::User(prompt.to_string())]); } let mut messages = Vec::new(); let mut iterated_prompt = prompt.to_owned(); while !iterated_prompt.trim().is_empty() { let current_role = if iterated_prompt.starts_with("User:") { AnthropicMessageRole::User } else { AnthropicMessageRole::Assistant }; let current_message = iterated_prompt .strip_prefix(format!("{}:", current_role).as_str()) .map(|s| s.trim()) .ok_or(AnthropicError::InvalidInput( "Prompt is not in the expected format".to_string(), ))?; if !current_message.contains("\n\n") { // Last message - it contains the entire content. match current_role { AnthropicMessageRole::User => { messages.push(AnthropicMessage::User(current_message.to_owned())); } AnthropicMessageRole::Assistant => { messages.push(AnthropicMessage::Assistant(current_message.to_owned())); } } break; } // Parse until the next role. 
let (current_message, next_message) = current_message .split_once("\n\n") .ok_or(AnthropicError::InvalidInput( "Prompt is not in the expected format".to_string(), ))?; match current_role { AnthropicMessageRole::User => { messages.push(AnthropicMessage::User(current_message.to_owned())); } AnthropicMessageRole::Assistant => { messages.push(AnthropicMessage::Assistant(current_message.to_owned())); } } iterated_prompt = next_message.trim().to_string(); } Ok(messages) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_messages_from_prompt_single_message() { let prompt = "Hello"; let messages = Anthropic::messages_from_prompt(prompt).unwrap(); assert_eq!(messages.len(), 1); let Some(msg) = messages.first() else { panic!("Expected at least one message"); }; let AnthropicMessage::User(content) = msg else { panic!("Expected a user message"); }; assert_eq!(content, "Hello"); } #[test] fn test_messages_from_prompt_multiple_messages() { let prompt = "User: Hello\n\nAssistant: Hi\n\nUser: How are you?"; let messages = Anthropic::messages_from_prompt(prompt).unwrap(); assert_eq!(messages.len(), 3); assert_eq!(messages[0], AnthropicMessage::User("Hello".to_string())); assert_eq!(messages[1], AnthropicMessage::Assistant("Hi".to_string())); assert_eq!( messages[2], AnthropicMessage::User("How are you?".to_string()) ); } } ================================================ FILE: core/src/lm/lm_provider/anthropic/mod.rs ================================================ mod builder; mod client; mod lm; mod models; pub use builder::*; pub use lm::*; pub use models::*; ================================================ FILE: core/src/lm/lm_provider/anthropic/models.rs ================================================ #![allow(dead_code)] /// Convenience constants for the Anthropic models. pub mod anthropic_model { pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-20240620"; pub const CLAUDE_3_OPUS: &str = "claude-3-opus-20240229"; pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229"; pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307"; } ================================================ FILE: core/src/lm/lm_provider/mod.rs ================================================ mod anthropic; mod models; mod ollama; mod openai; pub use anthropic::*; pub use models::*; pub use ollama::*; pub use openai::*; ================================================ FILE: core/src/lm/lm_provider/models.rs ================================================ use serde::{Deserialize, Serialize}; use crate::lm::{LanguageModel, LanguageModelProviderError}; use super::{Anthropic, Ollama, OpenAi}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum LanguageModelProvider { #[serde(rename = "ollama")] Ollama, #[serde(rename = "openai")] OpenAi, #[serde(rename = "anthropic")] Anthropic, } impl TryFrom<&str> for LanguageModelProvider { type Error = LanguageModelProviderError; fn try_from(value: &str) -> Result { Ok(match value { "ollama" => LanguageModelProvider::Ollama, "openai" => LanguageModelProvider::OpenAi, "anthropic" => LanguageModelProvider::Anthropic, _ => return Err(LanguageModelProviderError::InvalidValue(value.to_string())), }) } } impl LanguageModelProvider { /// Returns whether the provider runs local inference or not. 
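// Example (sketch, not part of the original source): parsing provider names via `TryFrom<&str>`
// and checking locality with `is_local`, written as a small test against the enum above.
#[cfg(test)]
mod provider_examples {
    use super::*;

    #[test]
    fn parses_provider_names_and_reports_locality() {
        let ollama = LanguageModelProvider::try_from("ollama").expect("known provider");
        assert!(ollama.is_local());
        let anthropic = LanguageModelProvider::try_from("anthropic").expect("known provider");
        assert!(!anthropic.is_local());
        assert!(LanguageModelProvider::try_from("not-a-provider").is_err());
    }
}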
pub fn is_local(&self) -> bool { match self { LanguageModelProvider::Ollama => true, LanguageModelProvider::OpenAi => false, LanguageModelProvider::Anthropic => false, } } } pub enum OrchLanguageModel { Ollama(Ollama), OpenAi(OpenAi), Anthropic(Anthropic), } impl OrchLanguageModel { pub fn provider(&self) -> LanguageModelProvider { match self { OrchLanguageModel::Ollama(_) => LanguageModelProvider::Ollama, OrchLanguageModel::OpenAi(_) => LanguageModelProvider::OpenAi, OrchLanguageModel::Anthropic(_) => LanguageModelProvider::Anthropic, } } pub fn text_completion_model_name(&self) -> String { match self { OrchLanguageModel::Ollama(lm) => lm.text_completion_model_name(), OrchLanguageModel::OpenAi(lm) => lm.text_completion_model_name(), OrchLanguageModel::Anthropic(lm) => lm.text_completion_model_name(), } } pub fn embedding_model_name(&self) -> String { match self { OrchLanguageModel::Ollama(lm) => lm.embedding_model_name(), OrchLanguageModel::OpenAi(lm) => lm.embedding_model_name(), OrchLanguageModel::Anthropic(lm) => lm.embedding_model_name(), } } } ================================================ FILE: core/src/lm/lm_provider/ollama/builder.rs ================================================ use thiserror::Error; use crate::lm::{LanguageModelBuilder, LanguageModelBuilderError}; use super::config::{DEFAULT_BASE_URL, DEFAULT_EMBEDDINGS_MODEL, DEFAULT_MODEL}; use super::Ollama; #[derive(Debug, Error)] pub enum OllamaBuilderError { #[error("Configuration error: {0} is not set")] ConfigurationNotSet(String), } /// Builds an [`Ollama`] instance. pub struct OllamaBuilder { /// Base URL for the Ollama API. Defaults to [`DEFAULT_BASE_URL`]. base_url: Option, /// Model to use for text completion. Defaults to [`DEFAULT_MODEL`]. model: Option, /// Model to use for embedding generation. Defaults to [`DEFAULT_EMBEDDINGS_MODEL`]. embeddings_model: Option, } impl OllamaBuilder { /// Overrides the default base URL for the Ollama API. /// Defaults to [`DEFAULT_BASE_URL`]. pub fn with_base_url(mut self, base_url: String) -> Self { self.base_url = Some(base_url); self } /// Overrides the default model to use for text completion. /// Defaults to [`DEFAULT_MODEL`]. pub fn with_model(mut self, model: String) -> Self { self.model = Some(model); self } pub fn with_embeddings_model(mut self, embeddings_model: String) -> Self { self.embeddings_model = Some(embeddings_model); self } } impl LanguageModelBuilder for OllamaBuilder { fn new() -> Self { Self { base_url: Some(DEFAULT_BASE_URL.to_string()), model: Some(DEFAULT_MODEL.to_string()), embeddings_model: Some(DEFAULT_EMBEDDINGS_MODEL.to_string()), } } /// Tries to build an [`Ollama`] instance. May fail if the required configurations are not set. fn try_build(self) -> Result { let Some(base_url) = self.base_url else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "Base URL".to_string(), )); }; let Some(model) = self.model else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "Model".to_string(), )); }; let Some(embeddings_model) = self.embeddings_model else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "Embeddings model".to_string(), )); }; Ok(Ollama { base_url: base_url.to_owned(), model: model.to_owned(), embeddings_model: embeddings_model.to_owned(), }) } } ================================================ FILE: core/src/lm/lm_provider/ollama/config.rs ================================================ use super::{ollama_embedding_model, ollama_model}; /// Default base URL for the Ollama API. 
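// Example (sketch, not part of the original source): building an `Ollama` language model with
// explicit overrides. The `LanguageModelBuilder` trait and the `ollama_model` /
// `ollama_embedding_model` constants are assumed to be in scope; the overrides shown here match
// the documented defaults and are only meant to illustrate the API.
fn example_build_ollama() -> Result<Ollama, LanguageModelBuilderError> {
    OllamaBuilder::new()
        .with_base_url("http://localhost:11434".to_string())
        .with_model(ollama_model::LLAMA3_1_8B.to_string())
        .with_embeddings_model(ollama_embedding_model::NOMIC_EMBED_TEXT.to_string())
        .try_build()
}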
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434"; /// Default model to use for text completion. pub const DEFAULT_MODEL: &str = ollama_model::LLAMA3_1_8B; /// Default model to use for embedding generation. pub const DEFAULT_EMBEDDINGS_MODEL: &str = ollama_embedding_model::NOMIC_EMBED_TEXT; ================================================ FILE: core/src/lm/lm_provider/ollama/lm.rs ================================================ use async_trait::async_trait; use lm::{ error::LanguageModelError, models::{ TextCompleteOptions, TextCompleteResponse, TextCompleteStreamOptions, TextCompleteStreamResponse, }, LanguageModel, LanguageModelProvider, }; use net::SseClient; use thiserror::Error; use tokio_stream::StreamExt; use crate::*; use super::{ OllamaApiModelsMetadata, OllamaEmbeddingsRequest, OllamaEmbeddingsResponse, OllamaGenerateRequest, OllamaGenerateResponse, OllamaGenerateStreamItemResponse, }; #[derive(Debug, Clone)] pub struct Ollama { pub base_url: String, pub model: String, pub embeddings_model: String, } #[derive(Error, Debug)] pub enum OllamaError { #[error("Unexpected response from API. Error: {0}")] Api(String), #[error("Unexpected error when parsing response from Ollama. Error: {0}")] Parsing(String), #[error("{0}")] Configuration(String), #[error("{0}")] Serialization(String), #[error( "Ollama API is not available. Please check if Ollama is running in the specified port. Error: {0}" )] ApiUnavailable(String), } impl Ollama { /// Lists the running models in the Ollama API. /// /// # Returns /// /// A [Result] containing the list of running models or an error if there was a problem. /// #[allow(dead_code)] pub(crate) fn list_running_models(&self) -> Result { let response = self.get_from_ollama_api("api/ps")?; let parsed_response = Self::parse_models_response(&response)?; Ok(parsed_response) } // /// Lists the local models in the Ollama API. // /// // /// # Returns // /// // /// A [Result] containing the list of local models or an error if there was a problem. 
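// Example (sketch, not part of the original source): listing the models available on a local
// Ollama instance using the listing helpers in this `impl` block. The field values are the
// documented defaults written inline. Note that these helpers use reqwest's blocking client
// internally, so they should not be called from inside an async runtime.
fn example_list_ollama_models() -> Result<(), OllamaError> {
    let ollama = Ollama {
        base_url: "http://localhost:11434".to_string(),
        model: "llama3.1:8b".to_string(),
        embeddings_model: "nomic-embed-text:latest".to_string(),
    };
    for model in ollama.list_local_models()?.models {
        println!("{} ({} bytes)", model.name, model.size);
    }
    Ok(())
}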
#[allow(dead_code)] pub fn list_local_models(&self) -> Result { let response = self.get_from_ollama_api("api/tags")?; let parsed_response = Self::parse_models_response(&response)?; Ok(parsed_response) } fn parse_models_response(response: &str) -> Result { let models: OllamaApiModelsMetadata = serde_json::from_str(response).map_err(|e| OllamaError::Parsing(e.to_string()))?; Ok(models) } fn get_from_ollama_api(&self, url: &str) -> Result { let url = format!("{}/{}", self.base_url, url); let client = reqwest::blocking::Client::new(); let response = client .get(url) .send() .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?; let response_text = response .text() .map_err(|e| OllamaError::Api(e.to_string()))?; Ok(response_text) } } #[async_trait] impl LanguageModel for Ollama { async fn text_complete( &self, prompt: &str, system_prompt: &str, _options: TextCompleteOptions, ) -> Result { let body = OllamaGenerateRequest { model: self.model.to_owned(), prompt: prompt.to_string(), system: Some(system_prompt.to_string()), ..Default::default() }; let client = reqwest::Client::new(); let url = format!("{}/api/generate", self.base_url); let response = client .post(url) .body(serde_json::to_string(&body).unwrap()) .send() .await .map_err(|e| LanguageModelError::Ollama(OllamaError::ApiUnavailable(e.to_string())))?; let body = response .text() .await .map_err(|e| LanguageModelError::Ollama(OllamaError::Api(e.to_string())))?; let ollama_response: OllamaGenerateResponse = serde_json::from_str(&body).map_err(|e| { LanguageModelError::Ollama(OllamaError::Parsing(format!( "{}. Received response: {body}", e ))) })?; match ollama_response { OllamaGenerateResponse::Success(success_response) => Ok(TextCompleteResponse { text: success_response.response, context: success_response.context, }), OllamaGenerateResponse::Error(error_response) => Err(LanguageModelError::Ollama( OllamaError::Api(format!("{error_response:?}")), )), } } async fn text_complete_stream( &self, prompt: &str, system_prompt: &str, options: TextCompleteStreamOptions, ) -> Result { let body = OllamaGenerateRequest { model: self.model.to_owned(), prompt: prompt.to_string(), stream: Some(true), format: None, images: None, system: Some(system_prompt.to_string()), keep_alive: Some("5m".to_string()), context: options.context, }; let url = format!("{}/api/generate", self.base_url); let stream = SseClient::post(&url, Some(serde_json::to_string(&body).unwrap())); let stream = stream.map(|event| { let parsed_message = serde_json::from_str::(&event); match parsed_message { Ok(message) => match message { OllamaGenerateStreamItemResponse::Success(success_response) => { Ok(success_response.response) } OllamaGenerateStreamItemResponse::Error(error_response) => Err( LanguageModelError::Ollama(OllamaError::Api(format!("{error_response:?}"))), ), }, Err(e) => Err(LanguageModelError::Ollama(OllamaError::Parsing( e.to_string(), ))), } }); let response = TextCompleteStreamResponse { stream: Box::pin(stream), }; Ok(response) } async fn generate_embedding(&self, prompt: &str) -> Result, LanguageModelError> { let client = reqwest::Client::new(); let url = format!("{}/api/embeddings", self.base_url); let body = OllamaEmbeddingsRequest { model: self.embeddings_model.to_owned(), prompt: prompt.to_string(), }; let response = client .post(url) .body( serde_json::to_string(&body) .map_err(|e| OllamaError::Serialization(e.to_string()))?, ) .send() .await .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?; let body = response .text() .await .map_err(|e| 
OllamaError::Api(e.to_string()))?; let response: OllamaEmbeddingsResponse = serde_json::from_str(&body).map_err(|e| OllamaError::Parsing(e.to_string()))?; Ok(response.embedding) } fn provider(&self) -> LanguageModelProvider { LanguageModelProvider::Ollama } fn text_completion_model_name(&self) -> String { self.model.to_string() } fn embedding_model_name(&self) -> String { self.embeddings_model.to_string() } } ================================================ FILE: core/src/lm/lm_provider/ollama/mod.rs ================================================ mod builder; mod config; mod lm; mod models; pub use builder::*; pub use lm::*; pub use models::*; ================================================ FILE: core/src/lm/lm_provider/ollama/models.rs ================================================ use serde::{Deserialize, Serialize}; pub mod ollama_model { /// https://ollama.com/library/llama3:latest pub const LLAMA3: &str = "llama3:latest"; /// https://ollama.com/library/llama3:8b pub const LLAMA3_8B: &str = "llama3:8b"; /// https://ollama.com/library/llama3.1:8b pub const LLAMA3_1_8B: &str = "llama3.1:8b"; /// https://ollama.com/library/phi3:latest pub const PHI3_MINI: &str = "phi3:latest"; /// https://ollama.com/library/codestral:latest pub const CODESTRAL: &str = "codestral:latest"; } pub mod ollama_embedding_model { /// https://ollama.com/library/nomic-embed-text:latest pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text:latest"; } /// Response from the Ollama API for obtaining information about local models. /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#list-running-models). #[derive(Debug, Serialize, Deserialize)] pub struct OllamaApiModelsMetadata { pub models: Vec, } /// Response item from the Ollama API for obtaining information about local models. /// /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22). #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaApiModelMetadata { /// The name of the model (e.g., "mistral:latest") pub name: String, /// The Ollama identifier of the model (e.g., "mistral:latest") pub model: String, /// Size of the model in bytes pub size: usize, /// Digest of the model using SHA256 (e.g., "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8") pub digest: String, /// Model expiry time in ISO 8601 format (e.g., "2024-06-04T14:38:31.83753-07:00") pub expires_at: Option, /// More details about the model pub details: OllamaApiModelDetails, } /// Details about a running model in the API for listing running models (`GET /api/ps`). /// /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22). #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaApiModelDetails { /// Model identifier that this model is based on pub parent_model: String, /// Format that this model is stored in (e.g., "gguf") pub format: String, /// Model family (e.g., "ollama") pub family: String, /// Parameters of the model (e.g., "7.2B") pub parameter_size: String, /// Quantization level of the model (e.g., "Q4_0" for 4-bit quantization) pub quantization_level: String, } /// Request for generating a response from the Ollama API. 
/// /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-a-completion). #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaGenerateRequest { /// Model identifier (e.g., "mistral:latest") pub model: String, /// The prompt to generate a response for (e.g., "List all Kubernetes pods") pub prompt: String, /// The context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory pub context: Option>, /// Optional list of base64-encoded images (for multimodal models such as `llava`) pub images: Option>, /// Optional format to use for the response (currently only "json" is supported) pub format: Option, /// Optional flag that controls whether the response is streamed or not (defaults to true). /// If `false`` the response will be returned as a single response object, rather than a stream of objects pub stream: Option, // System message (overrides what is defined in the Modelfile) pub system: Option, /// Controls how long the model will stay loaded into memory following the request (default: 5m) pub keep_alive: Option, } impl Default for OllamaGenerateRequest { fn default() -> Self { Self { model: ollama_model::CODESTRAL.to_string(), prompt: "".to_string(), stream: Some(false), format: None, images: None, system: Some("You are a helpful assistant".to_string()), keep_alive: Some("5m".to_string()), context: None, } } } /// Response from the Ollama API for generating a response. #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum OllamaGenerateResponse { Success(OllamaGenerateResponseSuccess), Error(OllamaGenerateResponseError), } #[derive(Debug, Serialize, Deserialize)] #[allow(dead_code)] pub struct OllamaGenerateResponseSuccess { /// Model identifier (e.g., "mistral:latest") pub model: String, /// Time at which the response was generated (ISO 8601 format) pub created_at: String, /// The response to the prompt pub response: String, /// The encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory pub context: Option>, /// The duration of the response in nanoseconds pub total_duration: usize, } #[derive(Debug, Serialize, Deserialize)] pub struct OllamaGenerateResponseError { /// Error message. pub error: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum OllamaGenerateStreamItemResponse { Success(OllamaGenerateStreamItemResponseSuccess), Error(OllamaGenerateStreamItemResponseError), } #[derive(Debug, Serialize, Deserialize)] pub struct OllamaGenerateStreamItemResponseSuccess { /// Model identifier (e.g., "mistral:latest") pub model: String, /// Time at which the response was generated (ISO 8601 format) pub created_at: String, /// The response to the prompt pub response: String, } #[derive(Debug, Serialize, Deserialize)] pub struct OllamaGenerateStreamItemResponseError { /// Error message. pub error: String, } /// Request for generating an embedding from the Ollama API. /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings). /// #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize, Default)] pub struct OllamaEmbeddingsRequest { /// The string to generate an embedding for. pub prompt: String, /// The model to use for the embedding generation. 
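// Example (sketch, not part of the original source): the JSON body that `OllamaGenerateRequest`
// serializes to for Ollama's `/api/generate` endpoint. The model and prompt values are
// illustrative only.
fn example_generate_request_body() -> serde_json::Result<String> {
    let request = OllamaGenerateRequest {
        model: "llama3.1:8b".to_string(),
        prompt: "What is the capital of France?".to_string(),
        system: Some("You are a helpful assistant".to_string()),
        stream: Some(false),
        ..Default::default()
    };
    // Produces roughly:
    // {"model":"llama3.1:8b","prompt":"What is the capital of France?","context":null,
    //  "images":null,"format":null,"stream":false,"system":"You are a helpful assistant",
    //  "keep_alive":"5m"}
    serde_json::to_string(&request)
}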
pub model: String, } /// Response from the Ollama API for generating an embedding. /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings). /// #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaEmbeddingsResponse { /// The embedding for the prompt. pub embedding: Vec, } ================================================ FILE: core/src/lm/lm_provider/openai/builder.rs ================================================ use thiserror::Error; use crate::lm::{lm_provider::openai::config, LanguageModelBuilder, LanguageModelBuilderError}; use super::OpenAi; #[derive(Debug, Error)] pub enum OpenAiBuilderError { #[error("Configuration error: {0} is not set")] ConfigurationNotSet(String), } /// Builds an [`OpenAi`] instance. pub struct OpenAiBuilder { api_endpoint: Option, api_key: Option, model: Option, embeddings_model: Option, } impl OpenAiBuilder { /// Sets the required API key for the OpenAI API. pub fn with_api_key(mut self, api_key: String) -> Self { self.api_key = Some(api_key); self } /// Overrides the default API endpoint for the OpenAI API. pub fn with_api_endpoint(mut self, api_endpoint: String) -> Self { self.api_endpoint = Some(api_endpoint); self } /// Sets the model to use for text completion. Defaults to [`config::DEFAULT_MODEL`]. pub fn with_model(mut self, model: String) -> Self { self.model = Some(model); self } /// Sets the model to use for embedding generation. Defaults to [`config::DEFAULT_EMBEDDINGS_MODEL`]. pub fn with_embeddings_model(mut self, embeddings_model: String) -> Self { self.embeddings_model = Some(embeddings_model.clone()); self } } impl LanguageModelBuilder for OpenAiBuilder { fn new() -> Self { Self { api_key: None, api_endpoint: None, model: Some(config::DEFAULT_MODEL.to_string()), embeddings_model: Some(config::DEFAULT_EMBEDDINGS_MODEL.to_string()), } } /// Tries to build an [`OpenAi`] instance. May fail if the required configurations are not set. fn try_build(self) -> Result { let Some(api_key) = self.api_key else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "API key".to_string(), )); }; let Some(model) = self.model else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "Model".to_string(), )); }; let Some(embeddings_model) = self.embeddings_model else { return Err(LanguageModelBuilderError::ConfigurationNotSet( "Embeddings model".to_string(), )); }; Ok(OpenAi { api_endpoint: self.api_endpoint, api_key: api_key.to_owned(), model: model.to_owned(), embeddings_model: embeddings_model.to_owned(), }) } } ================================================ FILE: core/src/lm/lm_provider/openai/config.rs ================================================ use super::{openai_embedding_model, openai_model}; /// Default model to use for text completion. pub const DEFAULT_MODEL: &str = openai_model::GPT_4O_MINI; /// Default model to use for embedding generation. 
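// Example (sketch, not part of the original source): building an `OpenAi` language model.
// The `LanguageModelBuilder` trait and the `openai_model` / `openai_embedding_model` constants
// are assumed to be in scope, and the `OPENAI_API_KEY` variable name is an assumption.
// `with_api_endpoint` could additionally be used to target an OpenAI-compatible gateway.
fn example_build_openai() -> Result<OpenAi, LanguageModelBuilderError> {
    OpenAiBuilder::new()
        .with_api_key(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set"))
        .with_model(openai_model::GPT_4O_MINI.to_string())
        .with_embeddings_model(openai_embedding_model::TEXT_EMBEDDING_3_SMALL.to_string())
        .try_build()
}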
pub const DEFAULT_EMBEDDINGS_MODEL: &str = openai_embedding_model::TEXT_EMBEDDING_ADA_002; ================================================ FILE: core/src/lm/lm_provider/openai/lm.rs ================================================ use async_trait::async_trait; use lm::{ error::LanguageModelError, models::{ TextCompleteOptions, TextCompleteResponse, TextCompleteStreamOptions, TextCompleteStreamResponse, }, LanguageModel, LanguageModelProvider, }; use openai_api_rs::v1::{ api::OpenAIClient, chat_completion::{self, ChatCompletionRequest}, embedding::EmbeddingRequest, }; use thiserror::Error; use tokio_stream::{self as stream}; use crate::*; #[derive(Debug, Clone)] pub struct OpenAi { pub api_endpoint: Option, pub api_key: String, pub model: String, pub embeddings_model: String, } #[derive(Error, Debug)] pub enum OpenAiError { #[error("Unexpected response from API. Error: {0}")] Api(String), #[error("Configuration error: {0}")] Configuration(String), #[error("Serialization error: {0}")] Serialization(String), #[error("OpenAI API is not available. Error: {0}")] ApiUnavailable(String), } #[async_trait] impl LanguageModel for OpenAi { async fn text_complete( &self, prompt: &str, system_prompt: &str, _options: TextCompleteOptions, ) -> Result { let mut client = OpenAIClient::new(self.api_key.to_owned()); if let Some(api_endpoint) = self.api_endpoint.clone() { client.api_endpoint = api_endpoint; } let messages = vec![ chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::system, content: chat_completion::Content::Text(system_prompt.to_owned()), name: None, tool_calls: None, tool_call_id: None, }, chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(prompt.to_owned()), name: None, tool_calls: None, tool_call_id: None, }, ]; // TODO: Support customization of max tokens and temperature. let req = ChatCompletionRequest::new(self.model.to_owned(), messages); let result = client .chat_completion(req) .await .map_err(|e| LanguageModelError::OpenAi(OpenAiError::Api(e.to_string())))?; let completion = result .choices .first() .unwrap() .message .content .clone() .unwrap(); Ok(TextCompleteResponse { text: completion, // TODO: Support context. context: None, }) } async fn text_complete_stream( &self, prompt: &str, system_prompt: &str, _options: TextCompleteStreamOptions, ) -> Result { // TODO: Support streaming - currently it just sends a single message. 
let text_completion_response = self .text_complete(prompt, system_prompt, TextCompleteOptions { context: None }) .await?; Ok(TextCompleteStreamResponse { stream: Box::pin(stream::once(Ok(text_completion_response.text))), }) } async fn generate_embedding(&self, prompt: &str) -> Result, LanguageModelError> { let client = OpenAIClient::new(self.api_key.to_owned()); let resp = client .embedding(EmbeddingRequest { model: self.embeddings_model.to_owned(), input: prompt.to_owned(), dimensions: None, user: None, }) .await .map_err(|e| LanguageModelError::OpenAi(OpenAiError::Api(e.to_string())))?; let data = resp.data.first().expect("Embedding data not found"); Ok(data.embedding.clone()) } fn provider(&self) -> LanguageModelProvider { LanguageModelProvider::OpenAi } fn text_completion_model_name(&self) -> String { self.model.to_string() } fn embedding_model_name(&self) -> String { self.embeddings_model.to_string() } } ================================================ FILE: core/src/lm/lm_provider/openai/mod.rs ================================================ mod builder; mod config; mod lm; mod models; pub use builder::*; pub use lm::*; pub use models::*; ================================================ FILE: core/src/lm/lm_provider/openai/models.rs ================================================ pub mod openai_model { pub const GPT_3_5_TURBO: &str = "gpt-3.5-turbo"; pub const GPT_4: &str = "gpt-4"; pub const GPT_4O_TURBO: &str = "gpt-4o-turbo"; pub const GPT_4O_MINI: &str = "gpt-4o-mini"; } pub mod openai_embedding_model { pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002"; pub const TEXT_EMBEDDING_ADA_002_DIMENSIONS: usize = 1536; pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small"; pub const TEXT_EMBEDDING_3_SMALL_DIMENSIONS: usize = 1536; pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large"; pub const TEXT_EMBEDDING_3_LARGE_DIMENSIONS: usize = 3072; } ================================================ FILE: core/src/lm/mod.rs ================================================ //! A module containing all logic related to LMs (Language Models). //! This don't strictly have to be *large* language models (i.e., SLMs such as Phi-3 or Mistral NeMo are included). mod builder; mod error; mod lm_provider; mod models; pub use builder::*; pub use error::*; pub use lm_provider::*; pub use models::*; ================================================ FILE: core/src/lm/models.rs ================================================ #![allow(dead_code)] use std::pin::Pin; use async_trait::async_trait; use dyn_clone::DynClone; use tokio_stream::Stream; use super::{error::LanguageModelError, LanguageModelProvider}; /// A trait for language model providers which implements text completion, embeddings, etc. /// /// > `DynClone` is used so that there can be dynamic dispatch of the `Llm` trait, /// > especially needed for [magic-cli](https://github.com/guywaldman/magic-cli). #[async_trait] pub trait LanguageModel: DynClone + Send + Sync { /// Generates a response from the LLM. /// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// * `options` - The options for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. /// async fn text_complete( &self, prompt: &str, system_prompt: &str, options: TextCompleteOptions, ) -> Result; /// Generates a streaming response from the LLM. 
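// Example (sketch, not part of the original source): a helper that is generic over any
// `LanguageModel` implementation (Ollama, OpenAI or Anthropic) through dynamic dispatch.
// The question and system prompt are illustrative.
async fn example_ask(lm: &dyn LanguageModel, question: &str) -> Result<String, LanguageModelError> {
    let response = lm
        .text_complete(question, "You are a concise assistant.", TextCompleteOptions::default())
        .await?;
    Ok(response.text)
}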
/// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// * `options` - The options for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. /// async fn text_complete_stream( &self, prompt: &str, system_prompt: &str, options: TextCompleteStreamOptions, ) -> Result; /// Generates an embedding from the LLM. /// /// # Arguments /// * `prompt` - The item to generate an embedding for. /// /// # Returns /// /// A [Result] containing the embedding or an error if there was a problem. async fn generate_embedding(&self, prompt: &str) -> Result, LanguageModelError>; /// Returns the provider of the LLM. fn provider(&self) -> LanguageModelProvider; /// Returns the name of the model used for text completions. fn text_completion_model_name(&self) -> String; /// Returns the name of the model used for embeddings. fn embedding_model_name(&self) -> String; } #[derive(Debug, Clone, Default)] pub struct TextCompleteOptions { /// An encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory. /// This should be as returned from the previous response. pub context: Option>, } #[derive(Debug, Clone, Default)] pub struct TextCompleteStreamOptions { pub context: Option>, } #[derive(Debug, Clone)] pub struct TextCompleteResponse { pub text: String, // TODO: This is specific to Ollama, context looks differently for other LLM providers. pub context: Option>, } pub struct TextCompleteStreamResponse { pub stream: Pin> + Send>>, // TODO: Handle context with streaming response. // pub context: Vec, } ================================================ FILE: core/src/net/mod.rs ================================================ /// Module for working with Server-Sent Events. mod sse; pub use sse::*; ================================================ FILE: core/src/net/sse.rs ================================================ use async_gen::AsyncIter; use reqwest::{header, Client}; use tokio_stream::Stream; /// A client for working with Server-Sent Events. pub struct SseClient; impl SseClient { pub fn post(url: &str, body: Option) -> impl Stream { let client = Client::new(); let mut req = Client::post(&client, url) .header(header::ACCEPT, "text/event-stream") .header(header::CACHE_CONTROL, "no-cache") .header(header::CONNECTION, "keep-alive") .header(header::CONTENT_TYPE, "application/json"); if let Some(body) = body { req = req.body(body); } let req = req.build().unwrap(); AsyncIter::from(async_gen::gen! 
{ let mut conn = client.execute(req).await.unwrap(); while let Some(event) = conn.chunk().await.unwrap() { yield std::str::from_utf8(&event).unwrap().to_owned(); } }) } } ================================================ FILE: core/src/response.rs ================================================ pub use orch_response::*; pub use orch_response_derive::*; ================================================ FILE: orch.code-workspace ================================================ { "folders": [ { "path": "core" }, { "path": "response" }, { "path": "response_derive" }, { "path": ".", "name": "(root)" } ], "settings": {} } ================================================ FILE: response/.gitignore ================================================ ================================================ FILE: response/Cargo.toml ================================================ [package] name = "orch_response" version = "0.0.16" edition = "2021" license = "MIT" description = "Models for orch Executor responses" homepage = "https://github.com/guywaldman/orch" repository = "https://github.com/guywaldman/orch" keywords = ["llm", "openai", "ollama", "rust"] [dependencies] dyn-clone = "1.0.17" serde = "1.0.204" serde_json = "1.0.120" ================================================ FILE: response/src/lib.rs ================================================ use dyn_clone::DynClone; /// Represents an option for the response of a language model. #[derive(Debug, Clone)] pub struct ResponseOption { /// The discriminator for the response type (e.g., `Answer` or `Fail`). pub type_name: String, /// The scenario for the response (e.g., "You know the answer" or "You don't know the answer"). pub scenario: String, /// The description of what the response represents (e.g., "The capital city of the received country" or "Explanation on why the capital city is not known"). pub description: String, /// The schema for the response. pub schema: Vec, } /// Represents a field in the schema of a response. #[derive(Debug, Clone)] pub struct ResponseSchemaField { /// Name of the field (e.g., "capital" for the capital city). pub name: String, /// Description of the field (e.g., "Capital city of the received country"). pub description: String, /// Type of the field (e.g., "string" for a string). pub typ: String, /// Example of the field (e.g., "London" for the capital city). 
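// Example (sketch, not part of the original source): a hand-written `ResponseOption` matching
// the documented "capital city" scenario, similar to what the `orch_response_derive` macros
// generate from an annotated struct.
fn example_answer_option() -> ResponseOption {
    ResponseOption {
        type_name: "Answer".to_string(),
        scenario: "You know the answer".to_string(),
        description: "The capital city of the received country".to_string(),
        schema: vec![ResponseSchemaField {
            name: "capital".to_string(),
            description: "Capital city of the received country".to_string(),
            typ: "string".to_string(),
            example: "London".to_string(),
        }],
    }
}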
pub example: String, } pub trait OrchResponseVariant: Send + Sync { fn variant() -> ResponseOption; } pub trait OrchResponseVariants: DynClone + Send + Sync { fn variants(&self) -> Vec; fn parse(&self, response: &str) -> Result; } ================================================ FILE: response_derive/.rustfmt.toml ================================================ max_width = 140 ================================================ FILE: response_derive/Cargo.toml ================================================ [package] name = "orch_response_derive" version = "0.0.16" edition = "2021" license = "MIT" description = "Derive macros for orch Executor responses" homepage = "https://github.com/guywaldman/orch" repository = "https://github.com/guywaldman/orch" keywords = ["llm", "openai", "ollama", "rust"] [lib] proc-macro = true [dependencies] orch_response = { path = "../response", version = "0.0.16" } darling = "0.20.10" proc-macro2 = "1.0.86" quote = "1.0.36" syn = "2.0.71" [dev-dependencies] serde = { version = "1.0.204", features = ["derive"] } serde_json = "1.0.120" ================================================ FILE: response_derive/README.md ================================================ # Procedural Macros for Orch ## Resources - https://www.freecodecamp.org/news/procedural-macros-in-rust/ ================================================ FILE: response_derive/src/attribute_impl.rs ================================================ use darling::FromMeta; /// #[variant(...)] #[derive(Debug, FromMeta)] pub(crate) struct VariantAttribute { pub(crate) variant: String, pub(crate) scenario: String, pub(crate) description: String, } /// #[schema(...)] #[derive(Debug, FromMeta)] pub(crate) struct SchemaAttribute { pub(crate) description: String, pub(crate) example: String, } ================================================ FILE: response_derive/src/derive_impl.rs ================================================ use darling::FromMeta; use quote::quote; use syn::{parse_macro_input, spanned::Spanned, DeriveInput, PathArguments}; use crate::attribute_impl::{SchemaAttribute, VariantAttribute}; pub(crate) fn response_variants_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut output = quote!(); // Bring traits into scope. output.extend(quote! { use ::orch_response::OrchResponseVariant; use ::serde::de::Error; }); let original_enum = parse_macro_input!(input as DeriveInput); let DeriveInput { data, ident, .. } = original_enum.clone(); let syn::Data::Enum(data) = data else { panic!("#[derive(OrchResponseVariants)] can only be used with enums"); }; let original_enum_ident = ident; let vec_capacity = data.variants.len(); let mut options_vec_pushes = quote!(); for variant in data.variants.iter() { let ident = syn::Ident::new( &get_enum_variant_struct_ident(variant).expect("Failed to parse enum variant"), variant.ident.span(), ); options_vec_pushes.extend(quote! { options.push(#ident::variant()); }); } // We construct a new struct that will be used to parse the response. // NOTE: This is hacky, but a workaround for the fact that the enum cannot be constructed. let derived_enum_struct_ident = syn::Ident::new(&format!("{}Derived", original_enum_ident), original_enum_ident.span()); output.extend(quote! { #[derive(::std::fmt::Debug, ::std::clone::Clone)] pub struct #derived_enum_struct_ident; }); // Note: We parse with a dynamic evaluation and looking at the `response_type` field, but this could be done // by deriving #[serde(tag = "response_type")] on the enum. 
let mut response_type_arms = quote!(); for variant in data.variants.iter() { let variant_ident = variant.ident.clone(); let variant_ident_str = syn::LitStr::new(&variant.ident.to_string(), variant.ident.span()); let struct_ident = syn::Ident::new( &get_enum_variant_struct_ident(variant).expect("Failed to parse enum variant"), variant.ident.span(), ); response_type_arms.extend(quote! { #variant_ident_str => Ok(#original_enum_ident::#variant_ident(serde_json::from_str::<#struct_ident>(response)?)), }); } output.extend(quote! { impl ::orch_response::OrchResponseVariants<#original_enum_ident> for #derived_enum_struct_ident { fn variants(&self) -> Vec<::orch_response::ResponseOption> { let mut options = Vec::with_capacity(#vec_capacity); #options_vec_pushes options } fn parse(&self, response: &str) -> Result<#original_enum_ident, ::serde_json::Error> { let dynamic_parsed = serde_json::from_str::(response)?; let response_type = dynamic_parsed.get("response_type"); let response_type = match response_type { Some(response_type) => match response_type.as_str() { Some(response_type) => response_type, None => { return Err(::serde_json::Error::custom(format!( "Invalid response type: {}", response ))); } } None => { return Err(::serde_json::Error::custom(format!( "Invalid response type: {}", response ))); } }; match response_type { #response_type_arms _ => Err(::serde_json::Error::custom("Invalid response type")), } } } }); output.into() } pub fn response_variant_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let original_struct = parse_macro_input!(input as DeriveInput); let DeriveInput { data, ident, attrs, .. } = original_struct.clone(); let syn::Data::Struct(data) = data else { panic!("#[derive(OrchResponseOption)] can only be used with structs"); }; let original_struct_ident = ident.clone(); let fields = data.fields; // Parse the #[variant(...)] attribute. let variant_attr = attrs .iter() .filter_map(|attr| VariantAttribute::from_meta(&attr.meta).ok()) .next() .expect("#[variant(...)] attribute not found on variant field"); let VariantAttribute { variant, scenario, description, } = variant_attr; // Parse the fields used in [`orch::response::OrchResponseVariant`]. let mut schema_fields = Vec::new(); for variant_field in fields.iter() { // Parse the #[schema(...)] attribute. let schema_attr = variant_field .attrs .iter() .filter_map(|attr| SchemaAttribute::from_meta(&attr.meta).ok()) .collect::>(); if schema_attr.len() != 1 { panic!("Expected a single #[schema(...)] attribute for each field of the enum variant with the correct format and parameters"); } let SchemaAttribute { description, example } = schema_attr.first().expect("Failed to parse schema attribute"); let typ = ast_type_to_str(&variant_field.ty).unwrap_or_else(|_| { panic!( "Failed to convert type to string for field `{}` of variant `{}`", variant_field.ident.as_ref().unwrap(), ident ) }); let typ = syn::LitStr::new(&typ, variant_field.span()); let field_ident = syn::LitStr::new(&variant_field.ident.as_ref().unwrap().to_string(), variant_field.span()); schema_fields.push(quote! { ::orch_response::ResponseSchemaField { name: #field_ident.to_string(), description: #description.to_string(), typ: #typ.to_string(), example: #example.to_string(), } }) } quote! 
{ impl ::orch_response::OrchResponseVariant for #original_struct_ident { fn variant() -> ::orch_response::ResponseOption { ::orch_response::ResponseOption { type_name: #variant.to_string(), scenario: #scenario.to_string(), description: #description.to_string(), schema: vec![ #(#schema_fields),* ] } } } } .into() } // Parse `Answer(AnswerResponseOption)` into `AnswerResponseOption`. fn get_enum_variant_struct_ident(variant: &syn::Variant) -> Result { // We expect the enum variant to look like this: `Answer(AnswerResponseOption)`, // so we parse the `AnswerResponseOption` struct. let syn::Fields::Unnamed(fields) = &variant.fields else { panic!("Expected an unnamed struct for each enum variant"); }; let Some(syn::Field { ty, .. }) = fields.unnamed.first() else { panic!("Expected an unnamed struct for each enum variant"); }; let syn::Type::Path(p) = ty else { panic!("Expected an unnamed struct for each enum variant"); }; let ident = &p.path.segments.first().unwrap().ident; Ok(ident.to_string()) } fn ast_type_to_str(ty: &syn::Type) -> Result { match ty { syn::Type::Path(tp) => { let ps = tp.path.segments.first(); let Some(first_path_segment) = ps else { return Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()); }; let t = first_path_segment.ident.to_string(); match t.as_ref() { "String" => { // SUPPORTED: String Ok("string".to_owned()) } "bool" => { // SUPPORTED: bool Ok("boolean".to_owned()) } "Option" => { let PathArguments::AngleBracketed(ab) = &tp.path.segments.first().unwrap().arguments else { return Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()); }; let syn::GenericArgument::Type(t) = ab.args.first().unwrap() else { return Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()); }; let syn::Type::Path(p) = t else { return Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()); }; let t = p.path.segments.first().unwrap().ident.to_string(); match t.as_ref() { "String" => { // SUPPORTED: Option Ok("string?".to_owned()) } "bool" => { // SUPPORTED: Option Ok("boolean?".to_owned()) } _ => Err(format!("Unsupported/unexpected type: {}", t).to_owned()), } } "Vec" => { let PathArguments::AngleBracketed(ab) = &tp.path.segments.first().unwrap().arguments else { return Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()); }; let syn::GenericArgument::Type(t) = ab.args.first().unwrap() else { return Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()); }; let syn::Type::Path(p) = t else { return Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()); }; let t = p.path.segments.first().unwrap().ident.to_string(); match t.as_ref() { // SUPPORTED: Vec "String" => Ok("string[]".to_owned()), _ => Err(format!("Unsupported/unexpected type: {}", t).to_owned()), } } _ => Err(format!("Unsupported/unexpected type: {}", t).to_owned()), } } _ => Err(format!("Unsupported/unexpected type: {:?}", ty).to_owned()), } } ================================================ FILE: response_derive/src/lib.rs ================================================ mod attribute_impl; mod derive_impl; use quote::quote; /// Used to derive the `OrchResponseVariants` trait for a given enum #[proc_macro_derive(Variants)] pub fn derive_orch_response_variants(input: proc_macro::TokenStream) -> proc_macro::TokenStream { derive_impl::response_variants_derive(input) } /// Used to derive the `OrchResponseVariant` trait for a given enum. 
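// Example (sketch, not part of the original source): the expected usage of these derives,
// inferred from the implementation above rather than copied from the crate's examples.
// Each enum variant wraps a struct deriving `Variant` with `#[variant(...)]` and `#[schema(...)]`
// attributes, and the enum derives `Variants`; `serde`, `serde_json` and `orch_response` are
// assumed to be available as dependencies.
use orch_response_derive::{Variant, Variants};
use serde::Deserialize;

#[derive(Variants)]
pub enum CapitalCityResponse {
    Answer(AnswerResponse),
    Fail(FailResponse),
}

#[derive(Variant, Debug, Deserialize)]
#[variant(variant = "Answer", scenario = "You know the answer", description = "The capital city of the received country")]
pub struct AnswerResponse {
    #[schema(description = "Capital city of the received country", example = "London")]
    pub capital: String,
}

#[derive(Variant, Debug, Deserialize)]
#[variant(variant = "Fail", scenario = "You don't know the answer", description = "Explanation on why the capital city is not known")]
pub struct FailResponse {
    #[schema(description = "Reason the capital city is not known", example = "The country was not recognized")]
    pub reason: String,
}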
#[proc_macro_derive(Variant, attributes(variant, schema))] pub fn derive_orch_response_variant_variant(input: proc_macro::TokenStream) -> proc_macro::TokenStream { derive_impl::response_variant_derive(input) } /// Used to construct the identifier of the derived enum. #[proc_macro] pub fn variants(input: proc_macro::TokenStream) -> proc_macro::TokenStream { // Expects the identifier of the derived enum. let enum_ident = syn::parse_macro_input!(input as syn::Ident); let derived_enum_ident = syn::Ident::new(&format!("{}Derived", enum_ident), enum_ident.span()); quote! { #derived_enum_ident {} } .into() } ================================================ FILE: scripts/ci.sh ================================================ #!/bin/bash set -e for i in 1 2 3; do systemctl is-active ollama.service && sudo systemctl stop ollama.service curl -fsSL https://ollama.com/install.sh | sh sleep 5 if systemctl is-active ollama.service; then break fi done ollama serve & ollama pull phi3:mini ollama pull nomic-embed-text:latest ================================================ FILE: scripts/examples.sh ================================================ #!/usr/bin/env bash SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "$SCRIPT_DIR/utils.sh" provider="$1" # If provider is not provided, or is not one of the supported providers, exit with an error if [[ "$provider" != "ollama" && "$provider" != "openai" && "$provider" != "anthropic" ]]; then echo "ERROR: Invalid provider. Supported providers: ollama, openai, anthropic" exit 1 fi FAILURE=0 info "Running all examples in the 'orch' crate for provider $provider..." pushd core 2>&1 >/dev/null for example in $(find examples -name '*.rs'); do example=${example%.rs} info "Running example: $(basename $example)" cargo run --quiet --example $(basename $example) -- $provider 1>/dev/null if [ $? -ne 0 ]; then FAILURE=1 error "Example $(basename $example) failed" fi success "Example $(basename $example) succeeded" done popd 2>&1 >/dev/null info "Ran all examples in the 'orch' crate for provider $provider" exit $FAILURE ================================================ FILE: scripts/utils.sh ================================================ #!/usr/bin/env bash if [ "${CI:-false}" != "true" ]; then bold=$(tput bold) normal=$(tput sgr0) light_grey=$(tput setaf 250) blue=$(tput setaf 4) green=$(tput setaf 2) yellow=$(tput setaf 3) red=$(tput setaf 1) else bold="" normal="" blue="" light_grey="" green="" yellow="" red="" fi function formatted_time { date +%FT%T.%3N } function formatted_severity { printf "%+6s:" $1 } function formatted_log { log_severity=$1 log_message="$2" echo "${bold}$(formatted_time) $(formatted_severity $log_severity) $log_message ${normal}" } function info { formatted_log INFO "$1" } function warn { formatted_log WARN "$1" } function success { echo "${green}$(formatted_log INFO "$1")" } function error { echo >&2 "${red}$(formatted_log ERROR "$1")" } function error-and-exit { error "$1" if [ "${2:''}" != "" ]; then exit $2 else exit 1 fi } ================================================ FILE: src/core/mod.rs ================================================ mod net; pub use net::*; ================================================ FILE: src/core/net/mod.rs ================================================ /// Module for working with Server-Sent Events. 
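// Example (sketch, not part of the original source): consuming the stream returned by
// `SseClient::post` in the `sse` module with `tokio_stream::StreamExt`. The URL and request
// body are illustrative.
async fn example_stream_events() {
    use tokio_stream::StreamExt;

    let body = r#"{"model":"llama3.1:8b","prompt":"Hello","stream":true}"#.to_string();
    let stream = SseClient::post("http://localhost:11434/api/generate", Some(body));
    // Pin the stream on the heap so `next` can be called on the `impl Stream` return type.
    let mut stream = Box::pin(stream);
    while let Some(chunk) = stream.next().await {
        println!("received chunk: {chunk}");
    }
}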
mod sse; pub use sse::*; ================================================ FILE: src/core/net/sse.rs ================================================ use async_gen::AsyncIter; use reqwest::{header, Client}; use tokio_stream::Stream; /// A client for working with Server-Sent Events. pub struct SseClient; impl SseClient { pub fn post(url: &str, body: Option<String>) -> impl Stream<Item = String> { let client = Client::new(); let mut req = Client::post(&client, url) .header(header::ACCEPT, "text/event-stream") .header(header::CACHE_CONTROL, "no-cache") .header(header::CONNECTION, "keep-alive") .header(header::CONTENT_TYPE, "application/json"); if let Some(body) = body { req = req.body(body); } let req = req.build().unwrap(); AsyncIter::from(async_gen::gen! { let mut conn = client.execute(req).await.unwrap(); while let Some(event) = conn.chunk().await.unwrap() { yield std::str::from_utf8(&event).unwrap().to_owned(); } }) } } ================================================ FILE: src/executor.rs ================================================ use std::pin::Pin; use thiserror::Error; use tokio_stream::Stream; use crate::{Llm, LlmError, TextCompleteOptions, TextCompleteStreamOptions}; pub struct Executor<'a, L: Llm> { llm: &'a L, } #[derive(Debug, Error)] pub enum ExecutorError { #[error("LLM error: {0}")] Llm(LlmError), } impl<'a, L: Llm> Executor<'a, L> { /// Creates a new `Executor` instance. /// /// # Arguments /// * `llm` - The LLM to use for the execution. pub fn new(llm: &'a L) -> Self { Self { llm } } /// Generates a response from the LLM (non-streaming). /// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. pub async fn text_complete( &self, prompt: &str, system_prompt: &str, ) -> Result<ExecutorTextCompleteResponse, ExecutorError> { let options = TextCompleteOptions { ..Default::default() }; let response = self .llm .text_complete(prompt, system_prompt, options) .await .map_err(ExecutorError::Llm)?; Ok(ExecutorTextCompleteResponse { text: response.text, context: ExecutorContext {}, }) } /// Generates a streaming response from the LLM. /// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. pub async fn text_complete_stream( &self, prompt: &str, system_prompt: &str, ) -> Result<ExecutorTextCompleteStreamResponse, ExecutorError> { let options = TextCompleteStreamOptions { ..Default::default() }; let response = self .llm .text_complete_stream(prompt, system_prompt, options) .await .map_err(ExecutorError::Llm)?; Ok(ExecutorTextCompleteStreamResponse { stream: response.stream, context: ExecutorContext {}, }) } /// Generates an embedding from the LLM. /// /// # Arguments /// * `prompt` - The item to generate an embedding for. /// /// # Returns /// /// A [Result] containing the embedding or an error if there was a problem. pub async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>, ExecutorError> { let response = self .llm .generate_embedding(prompt) .await .map_err(ExecutorError::Llm)?; Ok(response) } } // TODO: Support context for completions (e.g., IDs of past conversations in Ollama).
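// A rough usage sketch for `Executor` (hypothetical prompt; assumes a Tokio runtime and the
// `Ollama` provider defined elsewhere in this crate):
//
// ```ignore
// let ollama = Ollama::default();
// let executor = Executor::new(&ollama);
// let response = executor
//     .text_complete("What is the capital of France?", "You are a helpful assistant")
//     .await?;
// println!("{}", response.text);
// ```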
pub struct ExecutorContext; pub struct ExecutorTextCompleteResponse { pub text: String, pub context: ExecutorContext, } pub struct ExecutorTextCompleteStreamResponse { pub stream: Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>, pub context: ExecutorContext, } ================================================ FILE: src/lib.rs ================================================ mod core; mod executor; mod llm; // TODO: Narrow the scope of the use statements. pub use core::*; pub use executor::*; pub use llm::*; ================================================ FILE: src/llm/error.rs ================================================ use thiserror::Error; use crate::{LlmProvider, OllamaError}; #[derive(Debug, Error)] pub enum LlmProviderError { #[error("Invalid LLM provider: {0}")] InvalidValue(String), } impl std::fmt::Display for LlmProvider { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { LlmProvider::Ollama => write!(f, "ollama"), LlmProvider::OpenAi => write!(f, "openai"), } } } impl Default for LlmProvider { fn default() -> Self { Self::Ollama } } impl TryFrom<&str> for LlmProvider { type Error = LlmProviderError; fn try_from(value: &str) -> Result<Self, Self::Error> { match value { "ollama" => Ok(LlmProvider::Ollama), "openai" => Ok(LlmProvider::OpenAi), _ => Err(LlmProviderError::InvalidValue(value.to_string())), } } } #[derive(Debug, Error)] pub enum LlmError { #[error("Text generation error: {0}")] TextGeneration(String), #[error("Embedding generation error: {0}")] EmbeddingGeneration(String), #[error("Configuration error: {0}")] Configuration(String), #[error("Ollama error: {0}")] Ollama(#[from] OllamaError), } ================================================ FILE: src/llm/llm_provider/mod.rs ================================================ mod ollama; mod openai; pub use ollama::*; ================================================ FILE: src/llm/llm_provider/ollama/config.rs ================================================ use serde::{Deserialize, Serialize}; /// Default base URL for the Ollama API. pub const DEFAULT_BASE_URL: &str = "http://localhost:11434"; /// Default model for text completion. pub const DEFAULT_MODEL: &str = crate::ollama_model::CODESTRAL; /// Default model for embeddings.
pub const DEFAULT_EMBEDDING_MODEL: &str = "nomic-embed-text:latest"; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaConfig { pub base_url: Option<String>, pub model: Option<String>, pub embedding_model: Option<String>, } impl Default for OllamaConfig { fn default() -> Self { Self { base_url: Some(DEFAULT_BASE_URL.to_string()), model: Some("codestral:latest".to_string()), embedding_model: Some("nomic-embed-text:latest".to_string()), } } } ================================================ FILE: src/llm/llm_provider/ollama/llm.rs ================================================ use thiserror::Error; use tokio_stream::StreamExt; use crate::*; pub mod ollama_model { pub const CODESTRAL: &str = "codestral:latest"; } pub mod ollama_embedding_model { pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text:latest"; } #[derive(Debug, Clone)] pub struct Ollama<'a> { base_url: &'a str, pub model: Option<&'a str>, pub embeddings_model: Option<&'a str>, } impl Default for Ollama<'_> { fn default() -> Self { Self { base_url: "http://localhost:11434", model: Some(ollama_model::CODESTRAL), embeddings_model: Some(ollama_embedding_model::NOMIC_EMBED_TEXT), } } } pub struct OllamaBuilder<'a> { base_url: &'a str, model: Option<&'a str>, embeddings_model: Option<&'a str>, } impl Default for OllamaBuilder<'_> { fn default() -> Self { let ollama = Ollama::default(); Self { base_url: ollama.base_url, model: ollama.model, embeddings_model: ollama.embeddings_model, } } } impl<'a> OllamaBuilder<'a> { pub fn new() -> Self { Default::default() } pub fn with_base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } pub fn with_model(mut self, model: &'a str) -> Self { self.model = Some(model); self } pub fn with_embeddings_model(mut self, embeddings_model: &'a str) -> Self { self.embeddings_model = Some(embeddings_model); self } pub fn build(self) -> Ollama<'a> { Ollama { base_url: self.base_url, model: self.model, embeddings_model: self.embeddings_model, } } } #[derive(Error, Debug)] pub enum OllamaError { #[error("Unexpected response from API. Error: {0}")] Api(String), #[error("Unexpected error when parsing response from Ollama. Error: {0}")] Parsing(String), #[error("Configuration error: {0}")] Configuration(String), #[error("Serialization error: {0}")] Serialization(String), #[error( "Ollama API is not available. Please check if Ollama is running in the specified port. Error: {0}" )] ApiUnavailable(String), } impl<'a> Ollama<'a> { /// Lists the running models in the Ollama API. /// /// # Returns /// /// A [Result] containing the list of running models or an error if there was a problem. /// #[allow(dead_code)] pub(crate) fn list_running_models(&self) -> Result<OllamaApiModelsMetadata, OllamaError> { let response = self.get_from_ollama_api("api/ps")?; let parsed_response = Self::parse_models_response(&response)?; Ok(parsed_response) } // /// Lists the local models in the Ollama API. // /// // /// # Returns // /// // /// A [Result] containing the list of local models or an error if there was a problem.
#[allow(dead_code)] pub fn list_local_models(&self) -> Result<OllamaApiModelsMetadata, OllamaError> { let response = self.get_from_ollama_api("api/tags")?; let parsed_response = Self::parse_models_response(&response)?; Ok(parsed_response) } fn parse_models_response(response: &str) -> Result<OllamaApiModelsMetadata, OllamaError> { let models: OllamaApiModelsMetadata = serde_json::from_str(response).map_err(|e| OllamaError::Parsing(e.to_string()))?; Ok(models) } fn get_from_ollama_api(&self, url: &str) -> Result<String, OllamaError> { let url = format!("{}/{}", self.base_url()?, url); let client = reqwest::blocking::Client::new(); let response = client .get(url) .send() .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?; let response_text = response .text() .map_err(|e| OllamaError::Api(e.to_string()))?; Ok(response_text) } fn base_url(&self) -> Result<String, OllamaError> { Ok(self.base_url.to_string()) } fn model(&self) -> Result<String, OllamaError> { self.model .map(|s| s.to_owned()) .ok_or_else(|| OllamaError::Configuration("Model not set".to_string())) } fn embedding_model(&self) -> Result<String, OllamaError> { self.embeddings_model .map(|s| s.to_owned()) .ok_or_else(|| OllamaError::Configuration("Embedding model not set".to_string())) } } impl<'a> Llm for Ollama<'a> { async fn text_complete( &self, prompt: &str, system_prompt: &str, _options: TextCompleteOptions, ) -> Result<TextCompleteResponse, LlmError> { let body = OllamaGenerateRequest { model: self .model() .map_err(|_e| LlmError::Configuration("Model not set".to_string()))?, prompt: prompt.to_string(), system: Some(system_prompt.to_string()), ..Default::default() }; let client = reqwest::Client::new(); let url = format!( "{}/api/generate", self.base_url() .map_err(|_e| LlmError::Configuration("Base URL not set".to_string()))? ); let response = client .post(url) .body(serde_json::to_string(&body).unwrap()) .send() .await .map_err(|e| LlmError::Ollama(OllamaError::ApiUnavailable(e.to_string())))?; let body = response .text() .await .map_err(|e| LlmError::Ollama(OllamaError::Api(e.to_string())))?; let ollama_response: OllamaGenerateResponse = serde_json::from_str(&body) .map_err(|e| LlmError::Ollama(OllamaError::Parsing(e.to_string())))?; let response = TextCompleteResponse { text: ollama_response.response, context: ollama_response.context, }; Ok(response) } async fn text_complete_stream( &self, prompt: &str, system_prompt: &str, options: TextCompleteStreamOptions, ) -> Result<TextCompleteStreamResponse, LlmError> { let body = OllamaGenerateRequest { model: self.model()?, prompt: prompt.to_string(), stream: Some(true), format: None, images: None, system: Some(system_prompt.to_string()), keep_alive: Some("5m".to_string()), context: options.context, }; let url = format!("{}/api/generate", self.base_url()?); let stream = SseClient::post(&url, Some(serde_json::to_string(&body).unwrap())); let stream = stream.map(|event| { let parsed_message = serde_json::from_str::<OllamaGenerateStreamItemResponse>(&event); match parsed_message { Ok(message) => Ok(message.response), Err(e) => Err(LlmError::Ollama(OllamaError::Parsing(e.to_string()))), } }); let response = TextCompleteStreamResponse { stream: Box::pin(stream), }; Ok(response) } async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>, LlmError> { let client = reqwest::Client::new(); let url = format!("{}/api/embeddings", self.base_url()?); let body = OllamaEmbeddingsRequest { model: self.embedding_model()?, prompt: prompt.to_string(), }; let response = client .post(url) .body( serde_json::to_string(&body) .map_err(|e| OllamaError::Serialization(e.to_string()))?, ) .send() .await .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?; let body = response .text() .await .map_err(|e| OllamaError::Api(e.to_string()))?; let response:
OllamaEmbeddingsResponse = serde_json::from_str(&body).map_err(|e| OllamaError::Parsing(e.to_string()))?; Ok(response.embedding) } fn provider(&self) -> LlmProvider { LlmProvider::Ollama } fn text_completion_model_name(&self) -> String { self.model().expect("Model not set").to_string() } fn embedding_model_name(&self) -> String { self.embedding_model() .expect("Embedding model not set") .to_string() } } ================================================ FILE: src/llm/llm_provider/ollama/mod.rs ================================================ mod config; mod llm; mod models; pub use llm::*; pub use models::*; ================================================ FILE: src/llm/llm_provider/ollama/models.rs ================================================ use serde::{Deserialize, Serialize}; use crate::ollama_model; /// Response from the Ollama API for obtaining information about local models. /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#list-running-models). #[derive(Debug, Serialize, Deserialize)] pub struct OllamaApiModelsMetadata { pub models: Vec<OllamaApiModelMetadata>, } /// Response item from the Ollama API for obtaining information about local models. /// /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22). #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaApiModelMetadata { /// The name of the model (e.g., "mistral:latest") pub name: String, /// The Ollama identifier of the model (e.g., "mistral:latest") pub model: String, /// Size of the model in bytes pub size: usize, /// Digest of the model using SHA256 (e.g., "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8") pub digest: String, /// Model expiry time in ISO 8601 format (e.g., "2024-06-04T14:38:31.83753-07:00") pub expires_at: Option<String>, /// More details about the model pub details: OllamaApiModelDetails, } /// Details about a running model in the API for listing running models (`GET /api/ps`). /// /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22). #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaApiModelDetails { /// Model identifier that this model is based on pub parent_model: String, /// Format that this model is stored in (e.g., "gguf") pub format: String, /// Model family (e.g., "ollama") pub family: String, /// Parameters of the model (e.g., "7.2B") pub parameter_size: String, /// Quantization level of the model (e.g., "Q4_0" for 4-bit quantization) pub quantization_level: String, } /// Request for generating a response from the Ollama API. /// /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-a-completion).
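///
/// A rough construction sketch (illustrative values; mirrors how `text_complete` in `llm.rs` builds
/// the request body before POSTing it to `/api/generate`):
///
/// ```ignore
/// let request = OllamaGenerateRequest {
///     model: ollama_model::CODESTRAL.to_string(),
///     prompt: "List all Kubernetes pods".to_string(),
///     system: Some("You are a helpful assistant".to_string()),
///     ..Default::default()
/// };
/// let body = serde_json::to_string(&request).expect("request should serialize to JSON");
/// ```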
#[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaGenerateRequest { /// Model identifier (e.g., "mistral:latest") pub model: String, /// The prompt to generate a response for (e.g., "List all Kubernetes pods") pub prompt: String, /// The context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory pub context: Option<Vec<i32>>, /// Optional list of base64-encoded images (for multimodal models such as `llava`) pub images: Option<Vec<String>>, /// Optional format to use for the response (currently only "json" is supported) pub format: Option<String>, /// Optional flag that controls whether the response is streamed or not (defaults to true). /// If `false`, the response will be returned as a single response object, rather than a stream of objects pub stream: Option<bool>, /// System message (overrides what is defined in the Modelfile) pub system: Option<String>, /// Controls how long the model will stay loaded into memory following the request (default: 5m) pub keep_alive: Option<String>, } impl Default for OllamaGenerateRequest { fn default() -> Self { Self { model: ollama_model::CODESTRAL.to_string(), prompt: "".to_string(), stream: Some(false), format: None, images: None, system: Some("You are a helpful assistant".to_string()), keep_alive: Some("5m".to_string()), context: None, } } } #[derive(Debug, Serialize, Deserialize)] #[allow(dead_code)] pub struct OllamaGenerateResponse { /// Model identifier (e.g., "mistral:latest") pub model: String, /// Time at which the response was generated (ISO 8601 format) pub created_at: String, /// The response to the prompt pub response: String, /// The encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory pub context: Option<Vec<i32>>, /// The duration of the response in nanoseconds pub total_duration: usize, } #[derive(Debug, Serialize, Deserialize)] pub struct OllamaGenerateStreamItemResponse { /// Model identifier (e.g., "mistral:latest") pub model: String, /// Time at which the response was generated (ISO 8601 format) pub created_at: String, /// The response to the prompt pub response: String, } /// Request for generating an embedding from the Ollama API. /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings). /// #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize, Default)] pub struct OllamaEmbeddingsRequest { /// The string to generate an embedding for. pub prompt: String, /// The model to use for the embedding generation. pub model: String, } /// Response from the Ollama API for generating an embedding. /// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings). /// #[allow(dead_code)] #[derive(Debug, Serialize, Deserialize)] pub struct OllamaEmbeddingsResponse { /// The embedding for the prompt.
pub embedding: Vec<f32>, } ================================================ FILE: src/llm/llm_provider/openai.rs ================================================ // use async_trait::async_trait; // use openai_api_rs::v1::{ // api::OpenAIClient, // chat_completion::{self, ChatCompletionRequest}, // common::{GPT3_5_TURBO, GPT4, GPT4_O}, // }; // pub mod openai_model { // pub const GPT35_TURBO: &str = GPT35_TURBO; // pub const GPT4: &str = GPT4; // pub const GPT40: &str = GPT40; // } // pub struct OpenAi<'a> { // pub model: &'a str, // api_key: &'a str, // } // impl<'a> OpenAi<'a> { // pub fn new(api_key: &'a str, model: &'a str) -> Self { // Self { api_key, model } // } // } // #[async_trait] // impl<'a> TextCompletionLlm for OpenAi<'a> { // async fn complete( // &self, // system_prompts: &[String], // ) -> Result> { // let client = OpenAIClient::new(self.api_key.to_owned()); // let system_msgs = system_prompts // .iter() // .map(|p| chat_completion::ChatCompletionMessage { // role: chat_completion::MessageRole::system, // content: chat_completion::Content::Text(p.to_owned()), // name: None, // tool_calls: None, // tool_call_id: None, // }) // .collect::>(); // let mut req = ChatCompletionRequest::new(self.model.to_owned(), system_msgs); // req.max_tokens = Some(self.config.max_tokens as i64); // req.temperature = Some(self.config.temperature); // let result = client.chat_completion(req).await?; // let completion = result // .choices // .first() // .unwrap() // .message // .content // .clone() // .unwrap(); // Ok(completion) // } // } ================================================ FILE: src/llm/mod.rs ================================================ mod error; mod llm_provider; mod models; pub use error::*; pub use llm_provider::*; pub use models::*; ================================================ FILE: src/llm/models.rs ================================================ #![allow(dead_code)] use std::pin::Pin; use dyn_clone::DynClone; use serde::{Deserialize, Serialize}; use tokio_stream::Stream; use super::error::LlmError; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum LlmProvider { #[serde(rename = "ollama")] Ollama, #[serde(rename = "openai")] OpenAi, } /// A trait for LLM providers which implements text completion, embeddings, etc. /// /// > `DynClone` is used so that there can be dynamic dispatch of the `Llm` trait, /// > especially needed for [magic-cli](https://github.com/guywaldman/magic-cli). pub trait Llm: DynClone { /// Generates a response from the LLM. /// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// * `options` - The options for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. /// fn text_complete( &self, prompt: &str, system_prompt: &str, options: TextCompleteOptions, ) -> impl std::future::Future<Output = Result<TextCompleteResponse, LlmError>> + Send; /// Generates a streaming response from the LLM. /// /// # Arguments /// * `prompt` - The prompt to generate a response for. /// * `system_prompt` - The system prompt to use for the generation. /// * `options` - The options for the generation. /// /// # Returns /// A [Result] containing the response from the LLM or an error if there was a problem. /// fn text_complete_stream( &self, prompt: &str, system_prompt: &str, options: TextCompleteStreamOptions, ) -> impl std::future::Future<Output = Result<TextCompleteStreamResponse, LlmError>> + Send; /// Generates an embedding from the LLM.
/// /// # Arguments /// * `prompt` - The item to generate an embedding for. /// /// # Returns /// /// A [Result] containing the embedding or an error if there was a problem. fn generate_embedding( &self, prompt: &str, ) -> impl std::future::Future<Output = Result<Vec<f32>, LlmError>> + Send; /// Returns the provider of the LLM. fn provider(&self) -> LlmProvider; /// Returns the name of the model used for text completions. fn text_completion_model_name(&self) -> String; /// Returns the name of the model used for embeddings. fn embedding_model_name(&self) -> String; } #[derive(Debug, Clone, Default)] pub struct TextCompleteOptions { /// An encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory. /// This should be as returned from the previous response. pub context: Option<Vec<i32>>, } #[derive(Debug, Clone, Default)] pub struct TextCompleteStreamOptions { pub context: Option<Vec<i32>>, } #[derive(Debug, Clone)] pub struct TextCompleteResponse { pub text: String, // TODO: This is specific to Ollama, context looks different for other LLM providers. pub context: Option<Vec<i32>>, } pub struct TextCompleteStreamResponse { pub stream: Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>, // TODO: Handle context with streaming response. // pub context: Vec, } #[derive(Debug)] pub(crate) struct SystemPromptResponseOption { pub scenario: String, pub type_name: String, pub response: String, pub schema: Vec<SystemPromptCommandSchemaField>, } #[derive(Debug)] pub(crate) struct SystemPromptCommandSchemaField { pub name: String, pub description: String, pub typ: String, pub example: String, }
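// A rough sketch of how the `context` fields above are meant to thread conversational memory between
// completions (hypothetical prompts; assumes a Tokio runtime and the `Ollama` provider from this crate):
//
// ```ignore
// let ollama = Ollama::default();
// let first = ollama
//     .text_complete("My name is Guy.", "You are a helpful assistant", TextCompleteOptions::default())
//     .await?;
// // Feed the returned encoding back into the next request to keep a short conversational memory.
// let options = TextCompleteOptions { context: first.context };
// let second = ollama
//     .text_complete("What is my name?", "You are a helpful assistant", options)
//     .await?;
// println!("{}", second.text);
// ```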