Repository: cfahlgren1/observers Branch: main Commit: ae4b5a985691 Files: 47 Total size: 93.4 KB Directory structure: gitextract_2kjfhakd/ ├── .github/ │ └── workflows/ │ ├── black.yml │ └── python_tests.yml ├── .gitignore ├── CONTRIBUTING.md ├── README.md ├── examples/ │ ├── models/ │ │ ├── aisuite_example.py │ │ ├── async_openai_example.py │ │ ├── hf_client_example.py │ │ ├── litellm_example.py │ │ ├── ollama_example.py │ │ ├── openai_example.py │ │ ├── stream_async_hf_client_example.py │ │ ├── stream_openai_example.py │ │ └── transformers_example.py │ ├── openai_function_calling_example.py │ ├── stores/ │ │ ├── argilla_example.py │ │ ├── datasets_example.py │ │ ├── duckdb_example.py │ │ └── opentelemetry_example.py │ └── vision_example.py ├── pyproject.toml ├── src/ │ └── observers/ │ ├── __init__.py │ ├── base.py │ ├── frameworks/ │ │ └── __init__.py │ ├── models/ │ │ ├── __init__.py │ │ ├── aisuite.py │ │ ├── base.py │ │ ├── hf_client.py │ │ ├── litellm.py │ │ ├── openai.py │ │ └── transformers.py │ └── stores/ │ ├── __init__.py │ ├── argilla.py │ ├── base.py │ ├── datasets.py │ ├── duckdb.py │ ├── migrations/ │ │ ├── 001_create_schema_version.sql │ │ ├── 002_add_arguments_field.sql │ │ └── __init__.py │ ├── opentelemetry.py │ └── sql_base.py └── tests/ ├── __init__.py ├── conftest.py ├── integration/ │ └── models/ │ ├── test_async_examples.py │ ├── test_examples.py │ └── test_stream_examples.py └── unit/ └── stores/ └── test_datasets.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/black.yml ================================================ name: Lint on: [push, pull_request] jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: psf/black@stable ================================================ FILE: .github/workflows/python_tests.yml ================================================ name: Python Tests on: [push, pull_request] jobs: build_and_test: runs-on: ubuntu-latest strategy: matrix: python-version: [ "3.10", "3.11", "3.12" ] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install pdm pdm install - name: Test with pytest run: pdm run pytest ================================================ FILE: .gitignore ================================================ # DuckDB Stores *.db # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so *.json # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ /store .db.wal .wal .db ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing ## Development We use [PDM](https://pdm-project.org/en/latest/) to manage dependencies and virtual environments. Make sure you have it installed and then run: ```bash pdm install ``` ## Publishing Configure the PyPI credentials through environment variables `PDM_PUBLISH_USERNAME="__token__"` and `PDM_PUBLISH_PASSWORD=` and run: ```bash pdm publish ``` ================================================ FILE: README.md ================================================

🤗🔭 Observers 🔭🤗

A Lightweight Library for AI Observability

![Observers Logo](./assets/observers.png)

## Installation

First things first! You can install the SDK with pip as follows:

```bash
pip install observers
```

If you want to use other LLM providers through AISuite or Litellm, install the corresponding extra:

```bash
pip install observers[aisuite] # or observers[litellm]
```

For OpenTelemetry support, install the following:

```bash
pip install observers[opentelemetry]
```

## Usage

We differentiate between observers and stores. Observers wrap generative AI APIs (like OpenAI or llama-index) and track their interactions. Stores are classes that sync these observations to different storage backends (like DuckDB or Hugging Face datasets).

To get started, run the code below. It wraps an OpenAI client and logs every interaction to the configured store; for `wrap_openai` the default is a local `DuckDBStore`. To sync your observations to the Hugging Face Hub instead, use the `DatasetsStore`: the dataset will be pushed to your personal workspace (http://hf.co/{your_username}). To learn how to configure stores, go to the next section.

```python
from openai import OpenAI
from observers import wrap_openai

openai_client = OpenAI()

client = wrap_openai(openai_client)

response = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Tell me a joke."}],
)
print(response)
```

## Observers

### Supported Observers

We support both sync and async versions of the following observers:

- [OpenAI](https://openai.com/) and every other LLM provider that implements the [OpenAI API message format](https://platform.openai.com/docs/api-reference)
- [Hugging Face transformers](https://huggingface.co/docs/transformers/index): the transformers library by Hugging Face offers a simple API with its [TextGenerationPipeline](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TextGenerationPipeline) for running LLM inference.
- [Hugging Face Inference Client](https://huggingface.co/docs/huggingface_hub/guides/inference): the official client for Hugging Face's [Serverless Inference API](https://huggingface.co/docs/api-inference/en/index), a fast API with a free tier for running LLM inference with models from the Hugging Face Hub.
- [AISuite](https://github.com/andrewyng/aisuite): an LLM router by Andrew Ng that maps to [a lot of LLM API providers](https://github.com/andrewyng/aisuite/tree/main/aisuite/providers) with a uniform interface.
- [Litellm](https://docs.litellm.ai/docs/): a library that lets you use [a lot of different LLM APIs](https://docs.litellm.ai/docs/providers) with a uniform interface.

### Change OpenAI compliant LLM provider

The `wrap_openai` function allows you to wrap any OpenAI compliant LLM provider. Take a look at [the example doing this for Ollama](./examples/models/ollama_example.py) for more details.
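For instance, wrapping a local [Ollama](https://ollama.com/) server only requires overriding the client's base URL. A minimal sketch mirroring the bundled example (it assumes Ollama is serving its OpenAI-compatible API at the default `http://localhost:11434/v1` and that the `llama3.1` model has been pulled):

```python
from openai import OpenAI
from observers import wrap_openai

# Any OpenAI compliant endpoint works; here we point the client at Ollama
openai_client = OpenAI(base_url="http://localhost:11434/v1")
client = wrap_openai(openai_client)

response = client.chat.completions.create(
    model="llama3.1",
    messages=[{"role": "user", "content": "Tell me a joke."}],
)
print(response)
```

All wrappers also accept optional `store`, `tags`, `properties`, and `logging_rate` arguments, so swapping providers and configuring tracking compose freely.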
## Stores

### Supported Stores

| Store | Example | Annotate | Local | Free | UI filters | SQL filters |
|-------|---------|----------|-------|------|------------|-------------|
| [Hugging Face Datasets](https://huggingface.co/docs/huggingface_hub/en/package_reference/io-management#datasets) | [example](./examples/stores/datasets_example.py) | ❌ | ❌ | ✅ | ✅ | ✅ |
| [DuckDB](https://duckdb.org/) | [example](./examples/stores/duckdb_example.py) | ❌ | ✅ | ✅ | ❌ | ✅ |
| [Argilla](https://argilla.io/) | [example](./examples/stores/argilla_example.py) | ✅ | ❌ | ✅ | ✅ | ❌ |
| [OpenTelemetry](https://opentelemetry.io/) | [example](./examples/stores/opentelemetry_example.py) | ?\* | ?\* | ?\* | ?\* | ?\* |
| [Honeycomb](https://honeycomb.io/) | [example](./examples/stores/opentelemetry_example.py) | ✅ | ❌ | ✅ | ✅ | ✅ |

\* These features, for the OpenTelemetry store, depend upon the provider you use.

### Viewing / Querying

#### Hugging Face Datasets Store

To view and query Hugging Face Datasets, you can use the [Hugging Face Datasets Viewer](https://huggingface.co/docs/hub/en/datasets-viewer). You can [find example datasets on the Hugging Face Hub](https://huggingface.co/datasets?other=observers). From there, you can query the dataset using SQL or build your own UI on top of it. Take a look at [the example](./examples/stores/datasets_example.py) for more details.

![Hugging Face Datasets Viewer](./assets/datasets.png)

#### DuckDB Store

The default store is [DuckDB](https://duckdb.org/) and can be viewed and queried using the [DuckDB CLI](https://duckdb.org/#quickinstall). Take a look at [the example](./examples/stores/duckdb_example.py) for more details.

```bash
> duckdb store.db
> from openai_records limit 10;
┌──────────────────────┬──────────────────────┬──────────────────────┬──────────────────────┬───┬─────────┬──────────────────────┬───────────┐
│          id          │        model         │      timestamp       │       messages       │ … │  error  │     raw_response     │ synced_at │
│       varchar        │       varchar        │      timestamp       │ struct("role" varc…  │   │ varchar │         json         │ timestamp │
├──────────────────────┼──────────────────────┼──────────────────────┼──────────────────────┼───┼─────────┼──────────────────────┼───────────┤
│ 89cb15f1-d902-4586…  │ Qwen/Qwen2.5-Coder…  │ 2024-11-19 17:12:3…  │ [{'role': user, 'c… │ … │         │ {"id": "", "choice…  │           │
│ 415dd081-5000-4d1a…  │ Qwen/Qwen2.5-Coder…  │ 2024-11-19 17:28:5…  │ [{'role': user, 'c… │ … │         │ {"id": "", "choice…  │           │
│ chatcmpl-926         │ llama3.1             │ 2024-11-19 17:31:5…  │ [{'role': user, 'c… │ … │         │ {"id": "chatcmpl-9…  │           │
├──────────────────────┴──────────────────────┴──────────────────────┴──────────────────────┴───┴─────────┴──────────────────────┴───────────┤
│ 3 rows                                                                                                                 16 columns (7 shown) │
└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
```

#### Argilla Store

The Argilla Store allows you to sync your observations to [Argilla](https://argilla.io/). To use it, you first need to create a [free Argilla deployment on Hugging Face](https://docs.argilla.io/latest/getting_started/quickstart/). Take a look at [the example](./examples/stores/argilla_example.py) for more details.

![Argilla Store](./assets/argilla.png)

#### OpenTelemetry Store

The OpenTelemetry "Store" allows you to sync your observations to any provider that supports OpenTelemetry!
Examples are provided for [Honeycomb](https://honeycomb.io), but any provider that supplies OpenTelemetry compatible environment variables should Just Work®, and your queries will be executed as usual in your provider, against _trace_ data coming from Observers. ## Contributing See [CONTRIBUTING.md](./CONTRIBUTING.md) ================================================ FILE: examples/models/aisuite_example.py ================================================ import os import aisuite as ai from observers import wrap_aisuite # Initialize AI Suite client client = ai.Client() # Wrap client to enable tracking client = wrap_aisuite(client) # Set API keys os.environ["ANTHROPIC_API_KEY"] = "your-api-key" os.environ["OPENAI_API_KEY"] = "your-api-key" # Define models to test models = ["openai:gpt-4o", "anthropic:claude-3-5-sonnet-20240620"] # Define conversation messages messages = [ {"role": "system", "content": "Respond in Pirate English."}, {"role": "user", "content": "Tell me a joke."}, ] # Get completions from each model for model in models: response = client.chat.completions.create( model=model, messages=messages, temperature=0.75 ) print(response.choices[0].message.content) ================================================ FILE: examples/models/async_openai_example.py ================================================ import asyncio import os from openai import AsyncOpenAI from observers import wrap_openai openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) client = wrap_openai(openai_client) async def get_response() -> None: response = await client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me a joke."}], ) print(response) if __name__ == "__main__": import asyncio asyncio.run(get_response()) ================================================ FILE: examples/models/hf_client_example.py ================================================ import os from huggingface_hub import InferenceClient import observers api_key = os.getenv("HF_TOKEN") # Patch the HF client hf_client = InferenceClient(token=api_key) client = observers.wrap_hf_client(hf_client) response = client.chat.completions.create( model="Qwen/Qwen2.5-Coder-32B-Instruct", messages=[ { "role": "user", "content": "Write a function in Python that checks if a string is a palindrome.", } ], ) ================================================ FILE: examples/models/litellm_example.py ================================================ import os from litellm import completion from observers import wrap_litellm # Ensure you have both API keys set in environment variables os.environ["OPENAI_API_KEY"] = "your-api-key" os.environ["ANTHROPIC_API_KEY"] = "your-api-key" # Wrap the completion function to enable tracking client = wrap_litellm(completion) # Define models and messages models = ["gpt-3.5-turbo", "claude-3-5-sonnet-20240620"] messages = [{"content": "Hello, how are you?", "role": "user"}] # Get completions from each model for model in models: response = client.chat.completions.create( model=model, messages=messages, temperature=0.75 ) print(response.choices[0].message.content) ================================================ FILE: examples/models/ollama_example.py ================================================ from openai import OpenAI from observers import wrap_openai # Ollama is running locally at http://localhost:11434/v1 openai_client = OpenAI(base_url="http://localhost:11434/v1") client = wrap_openai(openai_client) response = client.chat.completions.create( model="llama3.1", messages=[ {"role": 
"user", "content": "Tell me a joke."}, ], ) print(response) ================================================ FILE: examples/models/openai_example.py ================================================ from openai import OpenAI from observers import wrap_openai openai_client = OpenAI() client = wrap_openai(openai_client) response = client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me a joke in the voice of a pirate."}], temperature=0.5, ) print(response.choices[0].message.content) ================================================ FILE: examples/models/stream_async_hf_client_example.py ================================================ import os from huggingface_hub import AsyncInferenceClient import observers api_key = os.getenv("HF_TOKEN") # Patch the HF client hf_client = AsyncInferenceClient(token=api_key) client = observers.wrap_hf_client(hf_client) async def get_response() -> None: response = await client.chat.completions.create( model="Qwen/Qwen2.5-Coder-32B-Instruct", messages=[ { "role": "user", "content": "Write a function in Python that checks if a string is a palindrome.", } ], stream=True, ) async for chunk in response: print(chunk) if __name__ == "__main__": import asyncio asyncio.run(get_response()) ================================================ FILE: examples/models/stream_openai_example.py ================================================ import asyncio import os from openai import AsyncOpenAI from observers import wrap_openai openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) client = wrap_openai(openai_client) async def get_response() -> None: response = await client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me a joke."}], stream=True, ) async for chunk in response: print(chunk) if __name__ == "__main__": import asyncio asyncio.run(get_response()) ================================================ FILE: examples/models/transformers_example.py ================================================ import os from transformers import pipeline import observers token = os.getenv("HF_TOKEN") pipe = pipeline( "text-generation", model="Qwen/Qwen2.5-0.5B-Instruct", token=token, ) client = observers.wrap_transformers(pipe) messages = [ { "role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!", }, {"role": "user", "content": "Who are you?"}, ] response = client.chat.completions.create( messages=messages, max_new_tokens=256, ) print(response) ================================================ FILE: examples/openai_function_calling_example.py ================================================ from openai import OpenAI from observers import wrap_openai from observers.stores import DatasetsStore store = DatasetsStore( repo_name="gpt-4o-function-calling-traces", every=5, # sync every 5 minutes ) openai_client = OpenAI() tools = [ { "type": "function", "function": { "name": "get_delivery_date", "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'", "parameters": { "type": "object", "properties": { "order_id": { "type": "string", "description": "The customer's order ID.", }, }, "required": ["order_id"], "additionalProperties": False, }, }, } ] messages = [ { "role": "system", "content": "You are a helpful customer support assistant. 
Use the supplied tools to assist the user.", }, { "role": "user", "content": "Hi, can you tell me the delivery date for my order? It's order 1234567890.", }, ] client = wrap_openai(openai_client, store=store) response = client.chat.completions.create( model="gpt-4o", messages=messages, tools=tools, ) ================================================ FILE: examples/stores/argilla_example.py ================================================ from argilla import TextQuestion # noqa from observers import wrap_openai from observers.stores import ArgillaStore from openai import OpenAI api_url = "" api_key = "" store = ArgillaStore( api_url=api_url, api_key=api_key, # questions=[TextQuestion(name="text")], optional ) openai_client = OpenAI() client = wrap_openai(openai_client, store=store) response = client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me a joke."}], ) print(response.choices[0].message.content) ================================================ FILE: examples/stores/datasets_example.py ================================================ from observers import wrap_openai from observers.stores import DatasetsStore from openai import OpenAI store = DatasetsStore( repo_name="gpt-4o-traces", every=5, # sync every 5 minutes ) openai_client = OpenAI() client = wrap_openai(openai_client, store=store) response = client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me a joke."}], ) print(response.choices[0].message.content) ================================================ FILE: examples/stores/duckdb_example.py ================================================ from observers import wrap_openai from observers.stores import DuckDBStore from openai import OpenAI store = DuckDBStore() openai_client = OpenAI() client = wrap_openai(openai_client, store=store) response = client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me a joke."}], ) ================================================ FILE: examples/stores/opentelemetry_example.py ================================================ import os from openai import OpenAI from observers import wrap_openai from observers.stores.opentelemetry import OpenTelemetryStore # Use your usual environment variables to configure OpenTelemetry # Here's an example for Honeycomb os.environ.setdefault("OTEL_SERVICE_NAME", "llm-observer-example") os.environ.setdefault("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") os.environ.setdefault("OTEL_EXPORTER_OTLP_ENDPOINT", "https://api.honeycomb.io") # Note: Keeping the sensitive ingest key in actual env vars, not in code # export OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=" store = OpenTelemetryStore() openai_client = OpenAI() client = wrap_openai(openai_client, store=store) response = client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me a joke."}] ) # The OpenTelemetryStore links multiple completions into a trace response = client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Tell me another joke."}] ) # Now query your OpenTelemetry-compatible observability store as you usually do!
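As a follow-up to the store examples above (a sketch, not a file from the repository): once `duckdb_example.py` has run, the recorded traces can be read back with plain SQL. The `store.db` path and the `openai_records` table name below follow from `DEFAULT_DB_NAME` and `ChatCompletionRecord.table_name` in the sources further down.

```python
import duckdb

# Open the store created by duckdb_example.py (default path: ./store.db)
conn = duckdb.connect("store.db")

# Each wrapped client writes to "<client_name>_records"; for OpenAI that is "openai_records"
rows = conn.execute(
    "SELECT model, assistant_message, total_tokens FROM openai_records LIMIT 5"
).fetchall()
for row in rows:
    print(row)
```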
================================================ FILE: examples/vision_example.py ================================================ from openai import OpenAI from observers import wrap_openai from observers.stores import DatasetsStore store = DatasetsStore( repo_name="gpt-4o-mini-vision-traces", every=5, # sync every 5 minutes ) openai_client = OpenAI() client = wrap_openai(openai_client, store=store) response = client.chat.completions.create( model="gpt-4o-mini", messages=[ { "role": "user", "content": [ {"type": "text", "text": "What’s in this image?"}, { "type": "image_url", "image_url": { "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", }, }, ], } ], max_tokens=300, ) print(response.choices[0].message.content) ================================================ FILE: pyproject.toml ================================================ [project] name = "observers" version = "0.2.0" description = "🤗 Observers: A Lightweight Library for AI Observability" authors = [ {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"}, ] keywords = [ "observability", "monitoring", "logging", "model-monitoring", "model-observability", "generative-ai", "ai", "traceability", "instrumentation", "instrumentation-library", "instrumentation-sdk", ] requires-python = "<3.13,>=3.10" readme = "README.md" license = {text = "Apache 2"} dependencies = [ "duckdb>=1.0.0", "datasets>=3.0.0", "openai>=1.50.0", "argilla>=2.3.0", ] [project.optional-dependencies] aisuite = [ "aisuite[all]>=0.1.6", ] dev = [ "pytest>=8.3.3", "black>=24.10.0", "jinja2>=3.1.4", "pytest-asyncio>=0.25.1", ] litellm = [ "litellm>=1.52", ] transformers = [ "transformers>=4.46.0", "torch>=2", ] opentelemetry = [ "opentelemetry-api>=1.28.0", "opentelemetry-sdk>=1.28.0", "opentelemetry-exporter-otlp-proto-grpc>=1.28.0", ] [build-system] requires = ["pdm-backend"] build-backend = "pdm.backend" [tool.pdm] distribution = true ================================================ FILE: src/observers/__init__.py ================================================ from typing import List from .models.aisuite import wrap_aisuite from .models.base import ChatCompletionObserver, ChatCompletionRecord from .models.hf_client import wrap_hf_client from .models.litellm import wrap_litellm from .models.openai import OpenAIRecord, wrap_openai from .models.transformers import TransformersRecord, wrap_transformers from .stores.argilla import ArgillaStore from .stores.base import Store from .stores.datasets import DatasetsStore from .stores.duckdb import DuckDBStore __all__: List[str] = [ "ChatCompletionObserver", "ChatCompletionRecord", "TransformersRecord", "OpenAIRecord", "wrap_openai", "wrap_transformers", "DatasetsStore", "Store", "wrap_aisuite", "wrap_litellm", "wrap_hf_client", "ArgillaStore", "DuckDBStore", ] ================================================ FILE: src/observers/base.py ================================================ import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional if TYPE_CHECKING: from argilla import Argilla @dataclass class Function: """Function tool call information""" name: str arguments: str @dataclass class ToolCall: """Tool call information""" id: str type: Literal["function"] function: Function @dataclass class Message: role: Literal["system", "user", "assistant", "function"] content: str tool_calls: Optional[List[ToolCall]] = None """The tool calls
generated by the model, such as function calls.""" function_call: Optional[Function] = None """Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model. """ @dataclass class Record(ABC): """ Base class for storing model response information """ client_name: str = field(init=False) id: str = field(default_factory=lambda: str(uuid.uuid4())) tags: Optional[List[str]] = None properties: Optional[Dict[str, Any]] = None error: Optional[str] = None raw_response: Optional[Dict] = None @property @abstractmethod def json_fields(self): """Return the DuckDB JSON fields for the record""" pass @property @abstractmethod def image_fields(self): """Return the DuckDB image fields for the record""" pass @property @abstractmethod def table_columns(self): """Return the DuckDB table columns for the record""" pass @property @abstractmethod def duckdb_schema(self): """Return the DuckDB schema for the record""" pass @property @abstractmethod def table_name(self): """Return the DuckDB table name for the record""" pass @abstractmethod def argilla_settings(self, client: "Argilla"): """Return the Argilla settings for the record""" pass ================================================ FILE: src/observers/frameworks/__init__.py ================================================ ================================================ FILE: src/observers/models/__init__.py ================================================ ================================================ FILE: src/observers/models/aisuite.py ================================================ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from observers.models.base import ( AsyncChatCompletionObserver, ChatCompletionObserver, ) from observers.models.openai import OpenAIRecord if TYPE_CHECKING: from aisuite import Client from observers.stores.argilla import ArgillaStore from observers.stores.datasets import DatasetsStore from observers.stores.duckdb import DuckDBStore class AisuiteRecord(OpenAIRecord): client_name: str = "aisuite" def wrap_aisuite( client: "Client", store: Optional[Union["DatasetsStore", "DuckDBStore", "ArgillaStore"]] = None, tags: Optional[List[str]] = None, properties: Optional[Dict[str, Any]] = None, logging_rate: Optional[float] = 1, ) -> Union[AsyncChatCompletionObserver, ChatCompletionObserver]: """Wraps an AISuite client to track API calls in a Store. Args: client (`Client`): The AISuite client to wrap. store (`Union[DuckDBStore, DatasetsStore]`, *optional*): The store to use to save the records. tags (`List[str]`, *optional*): The tags to associate with records. properties (`Dict[str, Any]`, *optional*): The properties to associate with records. logging_rate (`float`, *optional*): The logging rate to use for logging, defaults to 1 Returns: `ChatCompletionObserver`: The observer that wraps the AISuite client.
""" return ChatCompletionObserver( client=client, create=client.chat.completions.create, format_input=lambda messages, **kwargs: kwargs | {"messages": messages}, parse_response=AisuiteRecord.from_response, store=store, tags=tags, properties=properties, logging_rate=logging_rate, ) ================================================ FILE: src/observers/models/base.py ================================================ import datetime import random from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union from typing_extensions import Self from observers.base import Message, Record from observers.stores.datasets import DatasetsStore if TYPE_CHECKING: from argilla import Argilla from observers.stores.duckdb import DuckDBStore @dataclass class ChatCompletionRecord(Record): """ Data class for storing chat completion records. """ model: str = None timestamp: str = field(default_factory=lambda: datetime.datetime.now().isoformat()) arguments: Optional[Dict[str, Any]] = None messages: List[Message] = None assistant_message: Optional[str] = None completion_tokens: Optional[int] = None prompt_tokens: Optional[int] = None total_tokens: Optional[int] = None finish_reason: str = None tool_calls: Optional[Any] = None function_call: Optional[Any] = None @classmethod def from_response(cls, response=None, error=None, model=None, **kwargs): """Create a response record from an API response or error""" pass @property def table_columns(self): return [ "id", "model", "timestamp", "messages", "assistant_message", "completion_tokens", "prompt_tokens", "total_tokens", "finish_reason", "tool_calls", "function_call", "tags", "properties", "error", "raw_response", "arguments", ] @property def duckdb_schema(self): return f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( id VARCHAR PRIMARY KEY, model VARCHAR, timestamp TIMESTAMP, messages JSON, assistant_message TEXT, completion_tokens INTEGER, prompt_tokens INTEGER, total_tokens INTEGER, finish_reason VARCHAR, tool_calls JSON, function_call JSON, tags VARCHAR[], properties JSON, error VARCHAR, raw_response JSON, arguments JSON, ) """ def argilla_settings(self, client: "Argilla"): import argilla as rg from argilla import Settings return Settings( fields=[ rg.ChatField( name="messages", description="The messages sent to the assistant.", _client=client, ), rg.TextField( name="assistant_message", description="The response from the assistant.", required=False, client=client, ), rg.CustomField( name="tool_calls", template="{{ json record.fields.tool_calls }}", description="The tool calls made by the assistant.", required=False, _client=client, ), rg.CustomField( name="function_call", template="{{ json record.fields.function_call }}", description="The function call made by the assistant.", required=False, _client=client, ), rg.CustomField( name="properties", template="{{ json record.fields.properties }}", description="The properties associated with the response.", required=False, _client=client, ), rg.CustomField( name="raw_response", template="{{ json record.fields.raw_response }}", description="The raw response from the API.", required=False, _client=client, ), ], questions=[ rg.RatingQuestion( name="rating", description="How would you rate the response? 
1 being the worst and 5 being the best.", values=[1, 2, 3, 4, 5], ), rg.TextQuestion( name="improved_response", description="If you would like to improve the response, please provide a better response here.", required=False, ), rg.TextQuestion( name="context", description="If you would like to provide more context for the response or rating, please provide it here.", required=False, ), ], metadata=[ rg.IntegerMetadataProperty(name="completion_tokens", client=client), rg.IntegerMetadataProperty(name="prompt_tokens", client=client), rg.IntegerMetadataProperty(name="total_tokens", client=client), rg.TermsMetadataProperty(name="model", client=client), rg.TermsMetadataProperty(name="finish_reason", client=client), rg.TermsMetadataProperty(name="tags", client=client), ], ) @property def table_name(self): return f"{self.client_name}_records" @property def json_fields(self): return [ "tool_calls", "function_call", "tags", "properties", "raw_response", "arguments", ] @property def image_fields(self): return [] @property def text_fields(self): return [] class ChatCompletionObserver: """ Observer that provides an interface for tracking chat completions. Args: client (`Any`): The client to use for the chat completions. create (`Callable[..., Any]`): The function to use to create the chat completions, e.g. `chat.completions.create` for the OpenAI client. format_input (`Callable[[Dict[str, Any], Any], Any]`): The function to use to format the input messages. parse_response (`Callable[[Any], Dict[str, Any]]`): The function to use to parse the response. store (`Union["DuckDBStore", DatasetsStore]`, *optional*): The store to use to save the records. tags (`List[str]`, *optional*): The tags to associate with records. properties (`Dict[str, Any]`, *optional*): The properties to associate with records. logging_rate (`float`, *optional*): The logging rate to use for logging, defaults to 1 """ def __init__( self, client: Any, create: Callable[..., Any], format_input: Callable[[Dict[str, Any], Any], Any], parse_response: Callable[[Any], Dict[str, Any]], store: Optional[Union["DuckDBStore", DatasetsStore]] = None, tags: Optional[List[str]] = None, properties: Optional[Dict[str, Any]] = None, logging_rate: Optional[float] = 1, **kwargs: Any, ): self.client = client self.create_fn = create self.format_input = format_input self.parse_response = parse_response self.store = store or DatasetsStore.connect() self.tags = tags or [] self.properties = properties or {} self.kwargs = kwargs self.logging_rate = logging_rate @property def chat(self) -> Self: return self @property def completions(self) -> Self: return self def _log_record( self, response, error=None, model=None, messages=None, arguments=None ): record = self.parse_response( response, error=error, model=model, messages=messages, tags=self.tags, properties=self.properties, arguments=arguments, ) if random.random() < self.logging_rate: self.store.add(record) return record def create( self, messages: Dict[str, Any], **kwargs: Any, ) -> Any: """Creates a completion. Args: messages (`Dict[str, Any]`): The messages to send to the assistant. **kwargs: Additional arguments passed to the create function. If stream=True is passed, the function will return a generator yielding streamed responses. Returns: Any: The response from the assistant, or a generator if streaming.
""" response = None kwargs = self.handle_kwargs(kwargs) excluded_args = {"model", "messages", "tags", "properties"} arguments = {k: v for k, v in kwargs.items() if k not in excluded_args} model = kwargs.get("model") input_data = self.format_input(messages, **kwargs) if kwargs.get("stream", False): def stream_responses(): response_buffer = [] try: for chunk in self.create_fn(**input_data): yield chunk response_buffer.append(chunk) self._log_record( response_buffer, model=model, messages=messages, arguments=arguments, ) except Exception as e: self._log_record( response_buffer, error=e, model=model, messages=messages, arguments=arguments, ) raise return stream_responses() try: response = self.create_fn(**input_data) self._log_record( response, model=model, messages=messages, arguments=arguments ) return response except Exception as e: self._log_record( response, error=e, model=model, messages=messages, arguments=arguments ) raise def handle_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: """ Handle and process keyword arguments for the API call. This method merges the provided kwargs with the default kwargs stored in the instance. It ensures that any kwargs passed to the method call take precedence over the default ones. """ return {**self.kwargs, **kwargs} def __getattr__(self, attr: str) -> Any: if attr not in {"create", "chat", "messages"}: return getattr(self.client, attr) return getattr(self, attr) class AsyncChatCompletionObserver(ChatCompletionObserver): """ Async observer that provides an interface for tracking chat completions Args: client (`Any`): The async client to use for the chat completions. create (`Callable[..., Awaitable[Any]]`): The async function to use to create the chat completions. format_input (`Callable[[Dict[str, Any], Any], Any]`): The function to use to format the input messages. parse_response (`Callable[[Any], Dict[str, Any]]`): The function to use to parse the response. store (`Union["DuckDBStore", DatasetsStore]`, *optional*): The store to use to save the records. tags (`List[str]`, *optional*): The tags to include in the records. properties (`Dict[str, Any]`, *optional*): The properties to include in the records. logging_rate (`float`, *optional*): The logging rate to use for logging, defaults to 1 """ async def _log_record_async( self, response, error=None, model=None, messages=None, arguments=None ): record = self.parse_response( response, error=error, model=model, messages=messages, tags=self.tags, properties=self.properties, arguments=arguments, ) if random.random() < self.logging_rate: await self.store.add_async(record) return record async def create( self, messages: Dict[str, Any], **kwargs: Any, ) -> Any: """Create an async completion. Args: messages (`Dict[str, Any]`): The messages to send to the assistant. Returns: Any: The response from the assistant. 
""" response = None kwargs = self.handle_kwargs(kwargs) excluded_args = {"model", "messages", "tags", "properties"} arguments = {k: v for k, v in kwargs.items() if k not in excluded_args} model = kwargs.get("model") input_data = self.format_input(messages, **kwargs) if kwargs.get("stream", False): async def stream_responses(): response_buffer = [] try: async for chunk in await self.create_fn(**input_data): yield chunk response_buffer.append(chunk) await self._log_record_async( response_buffer, model=model, messages=messages, arguments=arguments, ) except Exception as e: await self._log_record_async( response_buffer, error=e, model=model, messages=messages, arguments=arguments, ) raise return stream_responses() try: response = await self.create_fn(**input_data) await self._log_record_async( response, model=model, messages=messages, arguments=arguments ) return response except Exception as e: await self._log_record_async( response, error=e, model=model, messages=messages, arguments=arguments ) raise async def __aenter__(self) -> "AsyncChatCompletionObserver": return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.store.close_async() ================================================ FILE: src/observers/models/hf_client.py ================================================ import uuid from dataclasses import asdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from huggingface_hub import AsyncInferenceClient, InferenceClient from observers.models.base import ( AsyncChatCompletionObserver, ChatCompletionObserver, ChatCompletionRecord, ) if TYPE_CHECKING: from huggingface_hub import ( ChatCompletionOutput, ChatCompletionStreamOutput, ) from observers.stores.datasets import DatasetsStore from observers.stores.duckdb import DuckDBStore class HFRecord(ChatCompletionRecord): client_name: str = "hf_client" @classmethod def from_response( cls, response: Union[ None, List["ChatCompletionStreamOutput"], "ChatCompletionOutput", ] = None, error=None, **kwargs, ) -> "HFRecord": """Create a response record from an API response or error Args: response: The response from the API. error: The error from the API. **kwargs: Additional arguments passed to the observer. 
""" if not response: return cls(finish_reason="error", error=str(error), **kwargs) # Handle streaming responses if isinstance(response, list): first_dump = asdict(response[0]) last_dump = asdict(response[-1]) id = first_dump.get("id") or str(uuid.uuid4()) choices = last_dump.get("choices", [{}])[0] delta = choices.get("delta", {}) content = "" total_tokens = prompt_tokens = completion_tokens = 0 raw_response = {} for i, r in enumerate(response): r_dump = asdict(r) raw_response[i] = r_dump usage = r_dump.get("usage", {}) total_tokens += usage.get("total_tokens", 0) prompt_tokens += usage.get("prompt_tokens", 0) completion_tokens += usage.get("completion_tokens", 0) content += ( r_dump.get("choices", [{}])[0].get("delta", {}).get("content") or "" ) return cls( id=id, completion_tokens=completion_tokens, prompt_tokens=prompt_tokens, total_tokens=total_tokens, assistant_message=content, finish_reason=choices.get("finish_reason"), tool_calls=delta.get("tool_calls"), function_call=delta.get("function_call"), raw_response=raw_response, **kwargs, ) # Handle non-streaming responses response_dump = asdict(response) choices = response_dump.get("choices", [{}])[0].get("message", {}) usage = response_dump.get("usage", {}) return cls( id=response_dump.get("id") or str(uuid.uuid4()), completion_tokens=usage.get("completion_tokens"), prompt_tokens=usage.get("prompt_tokens"), total_tokens=usage.get("total_tokens"), assistant_message=choices.get("content"), finish_reason=response_dump.get("choices", [{}])[0].get("finish_reason"), tool_calls=choices.get("tool_calls"), function_call=choices.get("function_call"), raw_response=response_dump, **kwargs, ) def wrap_hf_client( client: Union["InferenceClient", "AsyncInferenceClient"], store: Optional[Union["DuckDBStore", "DatasetsStore"]] = None, tags: Optional[List[str]] = None, properties: Optional[Dict[str, Any]] = None, logging_rate: Optional[float] = 1, ) -> Union["AsyncChatCompletionObserver", "ChatCompletionObserver"]: """ Wraps Hugging Face's Inference Client in an observer. Args: client (`Union[InferenceClient, AsyncInferenceClient]`): The HF Inference Client to wrap. store (`Union[DuckDBStore, DatasetsStore]`, *optional*): The store to use to save the records. tags (`List[str]`, *optional*): The tags to associate with records. properties (`Dict[str, Any]`, *optional*): The properties to associate with records. logging_rate (`float`, *optional*): The logging rate to use for logging, defaults to 1 Returns: `Union[AsyncChatCompletionObserver, ChatCompletionObserver]`: The observer that wraps the HF Inference Client. 
""" observer_args = { "client": client, "create": client.chat.completions.create, "format_input": lambda inputs, **kwargs: {"messages": inputs, **kwargs}, "parse_response": HFRecord.from_response, "store": store, "tags": tags, "properties": properties, "logging_rate": logging_rate, } if isinstance(client, AsyncInferenceClient): return AsyncChatCompletionObserver(**observer_args) return ChatCompletionObserver(**observer_args) ================================================ FILE: src/observers/models/litellm.py ================================================ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from observers.models.base import ( AsyncChatCompletionObserver, ChatCompletionObserver, ) from observers.models.openai import OpenAIRecord if TYPE_CHECKING: from litellm import acompletion, completion from observers.stores.argilla import ArgillaStore from observers.stores.datasets import DatasetsStore from observers.stores.duckdb import DuckDBStore class LitellmRecord(OpenAIRecord): client_name: str = "litellm" def wrap_litellm( client: Union["completion", "acompletion"], store: Optional[Union["DatasetsStore", "DuckDBStore", "ArgillaStore"]] = None, tags: Optional[List[str]] = None, properties: Optional[Dict[str, Any]] = None, logging_rate: Optional[float] = 1, ) -> Union[AsyncChatCompletionObserver, ChatCompletionObserver]: """ Wrap Litellm completion function to track API calls in a Store. Args: client (`Union[InferenceClient, AsyncInferenceClient]`): The HF Inference Client to wrap. store (`Union[DuckDBStore, DatasetsStore]`, *optional*): The store to use to save the records. tags (`List[str]`, *optional*): The tags to associate with records. properties (`Dict[str, Any]`, *optional*): The properties to associate with records. logging_rate (`float`, *optional*): The logging rate to use for logging, defaults to 1 Returns: `Union[AsyncChatCompletionObserver, ChatCompletionObserver]`: The observer that wraps the Litellm completion function. 
""" observer_args = { "client": client, "create": client, "format_input": lambda inputs, **kwargs: {"messages": inputs, **kwargs}, "parse_response": LitellmRecord.from_response, "store": store, "tags": tags, "properties": properties, "logging_rate": logging_rate, } if client.__name__ == "acompletion": return AsyncChatCompletionObserver(**observer_args) return ChatCompletionObserver(**observer_args) ================================================ FILE: src/observers/models/openai.py ================================================ import uuid from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from observers.stores.duckdb import DuckDBStore from openai import AsyncOpenAI, OpenAI from typing_extensions import Self from observers.models.base import ( AsyncChatCompletionObserver, ChatCompletionObserver, ChatCompletionRecord, ) if TYPE_CHECKING: from openai.types.chat import ChatCompletion, ChatCompletionChunk from observers.stores.datasets import DatasetsStore class OpenAIRecord(ChatCompletionRecord): client_name: str = "openai" @classmethod def from_response( cls, response: Union[List["ChatCompletionChunk"], "ChatCompletion"] = None, error=None, messages=None, **kwargs, ) -> Self: """Create a response record from an API response or error""" if not response: return cls( finish_reason="error", error=str(error), messages=messages, **kwargs ) # Handle streaming responses if isinstance(response, list): first_dump = response[0].model_dump() last_dump = response[-1].model_dump() content = "" completion_tokens = prompt_tokens = total_tokens = 0 choices = last_dump.get("choices", [{}])[0] delta = choices.get("delta", {}) raw_response = {} for i, r in enumerate(response): r_dump = r.model_dump() raw_response[i] = r_dump content += ( r_dump.get("choices", [{}])[0].get("delta", {}).get("content") or "" ) usage = r_dump.get("usage", {}) or {} completion_tokens += usage.get("completion_tokens", 0) prompt_tokens += usage.get("prompt_tokens", 0) total_tokens += usage.get("total_tokens", 0) return cls( id=first_dump.get("id") or str(uuid.uuid4()), messages=messages, completion_tokens=completion_tokens, prompt_tokens=prompt_tokens, total_tokens=total_tokens, assistant_message=content, finish_reason=choices.get("finish_reason"), tool_calls=delta.get("tool_calls"), function_call=delta.get("function_call"), raw_response=raw_response, **kwargs, ) # Handle non-streaming responses response_dump = response.model_dump() choices = response_dump.get("choices", [{}])[0].get("message", {}) usage = response_dump.get("usage", {}) or {} return cls( id=response.id or str(uuid.uuid4()), messages=messages, completion_tokens=usage.get("completion_tokens"), prompt_tokens=usage.get("prompt_tokens"), total_tokens=usage.get("total_tokens"), assistant_message=choices.get("content"), finish_reason=response_dump.get("choices", [{}])[0].get("finish_reason"), tool_calls=choices.get("tool_calls"), function_call=choices.get("function_call"), raw_response=response_dump, **kwargs, ) def wrap_openai( client: Union["OpenAI", "AsyncOpenAI"], store: Optional[Union["DuckDBStore", "DatasetsStore"]] = DuckDBStore(), tags: Optional[List[str]] = None, properties: Optional[Dict[str, Any]] = None, logging_rate: Optional[float] = 1, ) -> Union[ChatCompletionObserver, AsyncChatCompletionObserver]: """ Wraps an OpenAI client in an observer. Args: client (`Union[OpenAI, AsyncOpenAI]`): The OpenAI client to wrap. store (`Union[DuckDBStore, DatasetsStore]`, *optional*): The store to use to save the records. 
tags (`List[str]`, *optional*): The tags to associate with records. properties (`Dict[str, Any]`, *optional*): The properties to associate with records. logging_rate (`float`, *optional*): The logging rate to use for logging, defaults to 1 Returns: `Union[ChatCompletionObserver, AsyncChatCompletionObserver]`: The observer that wraps the OpenAI client. """ if store is None: store = DuckDBStore() observer_args = { "client": client, "create": client.chat.completions.create, "format_input": lambda messages, **kwargs: kwargs | {"messages": messages}, "parse_response": OpenAIRecord.from_response, "store": store, "tags": tags, "properties": properties, "logging_rate": logging_rate, } if isinstance(client, AsyncOpenAI): return AsyncChatCompletionObserver(**observer_args) return ChatCompletionObserver(**observer_args) ================================================ FILE: src/observers/models/transformers.py ================================================ import uuid from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from observers.models.base import ( ChatCompletionObserver, ChatCompletionRecord, ) if TYPE_CHECKING: from transformers import TextGenerationPipeline from observers.stores.datasets import DatasetsStore from observers.stores.duckdb import DuckDBStore class TransformersRecord(ChatCompletionRecord): """ Data class for storing transformer records. """ client_name: str = "transformers" @classmethod def from_response( cls, response: Dict[str, Any] = None, error: Exception = None, model: Optional[str] = None, **kwargs, ) -> "TransformersRecord": if not response: return cls(finish_reason="error", error=str(error), model=model, **kwargs) generated_text = response[0]["generated_text"][-1] return cls( id=str(uuid.uuid4()), model=model, assistant_message=generated_text.get("content"), tool_calls=generated_text.get("tool_calls"), raw_response=response, **kwargs, ) def wrap_transformers( client: "TextGenerationPipeline", store: Optional[Union["DuckDBStore", "DatasetsStore"]] = None, tags: Optional[List[str]] = None, properties: Optional[Dict[str, Any]] = None, logging_rate: Optional[float] = 1, ) -> ChatCompletionObserver: """ Wraps a transformers pipeline in an observer. Args: client (`transformers.TextGenerationPipeline`): The transformers pipeline to wrap. store (`Union[DuckDBStore, DatasetsStore]`, *optional*): The store to use to save the records. tags (`List[str]`, *optional*): The tags to associate with records. properties (`Dict[str, Any]`, *optional*): The properties to associate with records. logging_rate (`float`, *optional*): The logging rate to use for logging, defaults to 1 Returns: `ChatCompletionObserver`: The observer that wraps the transformers pipeline.
""" return ChatCompletionObserver( client=client, create=client.__call__, format_input=lambda inputs, **kwargs: {"text_inputs": inputs, **kwargs}, parse_response=TransformersRecord.from_response, store=store, tags=tags, properties=properties, logging_rate=logging_rate, ) ================================================ FILE: src/observers/stores/__init__.py ================================================ from observers.stores.argilla import ArgillaStore from observers.stores.datasets import DatasetsStore from observers.stores.duckdb import DuckDBStore __all__ = ["ArgillaStore", "DatasetsStore", "DuckDBStore"] ================================================ FILE: src/observers/stores/argilla.py ================================================ import uuid from dataclasses import asdict, dataclass, field from typing import TYPE_CHECKING, List, Optional, Union import argilla as rg from argilla import ( Argilla, LabelQuestion, MultiLabelQuestion, RankingQuestion, RatingQuestion, SpanQuestion, TextQuestion, ) from observers.stores.base import Store if TYPE_CHECKING: from observers.base import Record @dataclass class ArgillaStore(Store): """ Argilla store """ api_url: Optional[str] = field(default=None) api_key: Optional[str] = field(default=None) dataset_name: Optional[str] = field(default=None) workspace_name: Optional[str] = field(default=None) questions: Optional[ List[ Union[ TextQuestion, LabelQuestion, SpanQuestion, RatingQuestion, RankingQuestion, MultiLabelQuestion, ] ] ] = field(default=None) _dataset: Optional[rg.Dataset] = None _dataset_keys: Optional[List[str]] = None _client: Optional[Argilla] = None def __post_init__(self) -> None: """Initialize the store""" self._client = Argilla(api_url=self.api_url, api_key=self.api_key) def _init_table(self, record: "Record") -> None: dataset_name = ( self.dataset_name or f"{record.table_name}_{uuid.uuid4().hex[:8]}" ) workspace_name = self.workspace_name or self._client.me.username workspace = self._client.workspaces(name=workspace_name) if not workspace: workspace = self._client.workspaces.add(rg.Workspace(name=workspace_name)) dataset = self._client.datasets(name=dataset_name, workspace=workspace_name) if not dataset: settings = record.argilla_settings(self._client) if self.questions: settings.questions = self.questions dataset = rg.Dataset( name=dataset_name, workspace=workspace_name, settings=settings, client=self._client, ).create() elif self.questions: raise ValueError( "Custom questions are not supported for existing datasets." 
) self._dataset = dataset dataset_keys = ( [field.name for field in dataset.settings.fields] + [question.name for question in dataset.settings.questions] + [term.name for term in dataset.settings.metadata] + [vector.name for vector in dataset.settings.vectors] ) self._dataset_keys = dataset_keys @classmethod def connect( cls, api_url: Optional[str] = None, api_key: Optional[str] = None, dataset_name: Optional[str] = None, workspace_name: Optional[str] = None, ) -> "ArgillaStore": """Create a new store instance with custom settings""" return cls( api_url=api_url, api_key=api_key, dataset_name=dataset_name, workspace_name=workspace_name, ) def add(self, record: "Record") -> None: """Add a new record to the database""" if not self._dataset: self._init_table(record) record_dict = asdict(record) for text_field in record.text_fields: if text_field in record_dict: record_dict[f"{text_field}_length"] = len(record_dict[text_field]) record_dict = {k: v for k, v in record_dict.items() if k in self._dataset_keys} self._dataset.records.log([record_dict]) async def add_async(self, record: "Record"): """ Add a new record to the database asynchronously Args: record (`Record`): The record to add to the database. """ if not self._dataset: self._init_table(record) record_dict = asdict(record) for text_field in record.text_fields: if text_field in record_dict: record_dict[f"{text_field}_length"] = len(record_dict[text_field]) record_dict = {k: v for k, v in record_dict.items() if k in self._dataset_keys} # Use argilla's native async API await self._dataset.records.log( [record_dict], background=True, verbose=False, ) ================================================ FILE: src/observers/stores/base.py ================================================ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: from observers.base import Record @dataclass class Store(ABC): """ Base class for storing records """ @abstractmethod def add(self, record: "Record"): """Add a new record to the store""" pass @abstractmethod async def add_async(self, record: "Record"): """Add a new record to the store asynchronously""" pass @abstractmethod def connect(self): """Connect to the store""" pass @abstractmethod def _init_table(self, record: "Record"): """Initialize the table""" pass ================================================ FILE: src/observers/stores/datasets.py ================================================ import asyncio import atexit import base64 import hashlib import json import os import tempfile import uuid from dataclasses import asdict, dataclass, field from io import BytesIO from typing import TYPE_CHECKING, List, Optional from datasets.utils.logging import disable_progress_bar from huggingface_hub import CommitScheduler, login, metadata_update, whoami from PIL import Image from observers.stores.base import Store if TYPE_CHECKING: from observers.base import Record disable_progress_bar() @dataclass class DatasetsStore(Store): """ Datasets store """ org_name: Optional[str] = field(default=None) repo_name: Optional[str] = field(default=None) folder_path: Optional[str] = field(default=None) every: Optional[int] = field(default=5) path_in_repo: Optional[str] = field(default=None) revision: Optional[str] = field(default=None) private: Optional[bool] = field(default=None) token: Optional[str] = field(default=None) allow_patterns: Optional[List[str]] = field(default=None) ignore_patterns: Optional[List[str]] = field(default=None) squash_history: 
================================================ FILE: src/observers/stores/datasets.py ================================================ import asyncio import atexit import base64 import hashlib import json import os import tempfile import uuid from dataclasses import asdict, dataclass, field from io import BytesIO from typing import TYPE_CHECKING, List, Optional from datasets.utils.logging import disable_progress_bar from huggingface_hub import CommitScheduler, login, metadata_update, whoami from PIL import Image from observers.stores.base import Store if TYPE_CHECKING: from observers.base import Record disable_progress_bar() @dataclass class DatasetsStore(Store): """ Datasets store """ org_name: Optional[str] = field(default=None) repo_name: Optional[str] = field(default=None) folder_path: Optional[str] = field(default=None) every: Optional[int] = field(default=5) path_in_repo: Optional[str] = field(default=None) revision: Optional[str] = field(default=None) private: Optional[bool] = field(default=None) token: Optional[str] = field(default=None) allow_patterns: Optional[List[str]] = field(default=None) ignore_patterns: Optional[List[str]] = field(default=None) squash_history: Optional[bool] = field(default=None) _filename: Optional[str] = field(default=None) _scheduler: Optional[CommitScheduler] = None _temp_dir: Optional[str] = field(default=None, init=False) def __post_init__(self): """Initialize the store and create temporary directory""" if self.ignore_patterns is None: self.ignore_patterns = ["*.json"] try: whoami(token=self.token or os.getenv("HF_TOKEN")) except Exception: login() if self.folder_path is None: self._temp_dir = tempfile.mkdtemp(prefix="observers_dataset_") self.folder_path = self._temp_dir atexit.register(self._cleanup) else: os.makedirs(self.folder_path, exist_ok=True) def _cleanup(self): """Clean up temporary directory on exit""" if self._temp_dir and os.path.exists(self._temp_dir): import shutil shutil.rmtree(self._temp_dir) def _init_table(self, record: "Record"): import logging logging.getLogger("huggingface_hub").setLevel(logging.ERROR) repo_name = self.repo_name or f"{record.table_name}_{uuid.uuid4().hex[:8]}" org_name = self.org_name or whoami(token=self.token).get("name") repo_id = f"{org_name}/{repo_name}" self._filename = f"{record.table_name}_{uuid.uuid4()}.json" self._scheduler = CommitScheduler( repo_id=repo_id, folder_path=self.folder_path, every=self.every, path_in_repo=self.path_in_repo, repo_type="dataset", revision=self.revision, private=self.private, token=self.token, allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns, squash_history=self.squash_history, ) self._scheduler.private = self.private metadata_update( repo_id=repo_id, metadata={"tags": ["observers", record.table_name.split("_")[0]]}, repo_type="dataset", token=self.token, overwrite=True, ) @classmethod def connect( cls, org_name: Optional[str] = None, repo_name: Optional[str] = None, folder_path: Optional[str] = None, every: Optional[int] = 5, path_in_repo: Optional[str] = None, revision: Optional[str] = None, private: Optional[bool] = None, token: Optional[str] = None, allow_patterns: Optional[List[str]] = None, ignore_patterns: Optional[List[str]] = None, squash_history: Optional[bool] = None, ) -> "DatasetsStore": """Create a new store instance with optional custom path""" return cls( org_name=org_name, repo_name=repo_name, folder_path=folder_path, every=every, path_in_repo=path_in_repo, revision=revision, private=private, token=token, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, squash_history=squash_history, ) def add(self, record: "Record"): """Add a new record to the database""" if not self._scheduler: self._init_table(record) with self._scheduler.lock: with (self._scheduler.folder_path / self._filename).open("a") as f: record_dict = asdict(record) # Handle JSON fields for json_field in record.json_fields: if record_dict[json_field]: record_dict[json_field] = json.dumps(record_dict[json_field]) # Handle image fields for image_field in record.image_fields: if record_dict[image_field]: image_folder = self._scheduler.folder_path / "images" image_folder.mkdir(exist_ok=True) # Generate unique filename based on record content filtered_dict = { k: v for k, v in sorted(record_dict.items()) if k not in ["uri", image_field, "id"] } content_hash = hashlib.sha256( json.dumps(obj=filtered_dict, sort_keys=True).encode() ).hexdigest() image_path = image_folder / f"{content_hash}.png" # Save image and update record image_bytes = base64.b64decode( record_dict[image_field]["bytes"] ) Image.open(BytesIO(image_bytes)).save(image_path) record_dict[image_field].update( {"path": str(image_path), "bytes": None} ) # Clean up empty dictionaries
record_dict = { k: None if v == {} else v for k, v in record_dict.items() } sorted_dict = { col: record_dict.get(col) for col in record.table_columns } f.write(json.dumps(sorted_dict) + "\n") f.flush() async def add_async(self, record: "Record"): """Add a new record to the database asynchronously""" await asyncio.to_thread(self.add, record) async def close_async(self): """Close the dataset store asynchronously""" if self._scheduler: await asyncio.to_thread(self._scheduler.__exit__, None, None, None) self._scheduler = None def close(self): """Close the dataset store synchronously""" if self._scheduler: self._scheduler.__exit__(None, None, None) self._scheduler = None
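`DatasetsStore` buffers records as JSON lines in a local folder and relies on `huggingface_hub.CommitScheduler` to push them to a Hub dataset repo in the background. A minimal usage sketch, assuming a valid HF token is configured; the org and repo names below are placeholders, not values from this repository:

```python
from observers.stores.datasets import DatasetsStore

# org/repo names are placeholders; with repo_name=None a suffixed name
# is generated from the record's table name on the first add()
store = DatasetsStore.connect(
    org_name="my-org",
    repo_name="llm-observations",
    private=True,
    every=1,  # CommitScheduler interval, in minutes
)
# store.add(record) appends a JSON line; the scheduler commits it to the Hub.
# store.close() flushes pending commits and stops the scheduler.
```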
================================================ FILE: src/observers/stores/duckdb.py ================================================ import asyncio import glob import json import os import re from dataclasses import asdict, dataclass, field from pathlib import Path from typing import TYPE_CHECKING, List, Optional import duckdb from observers.stores.sql_base import SQLStore if TYPE_CHECKING: from observers.base import Record DEFAULT_DB_NAME = "store.db" @dataclass class DuckDBStore(SQLStore): """ DuckDB store """ path: str = field( default_factory=lambda: os.path.join(os.getcwd(), DEFAULT_DB_NAME) ) _tables: List[str] = field(default_factory=list) _conn: Optional[duckdb.DuckDBPyConnection] = None def __post_init__(self): """Initialize database connection and table""" if self._conn is None: self._conn = duckdb.connect(self.path) self._tables = self._get_tables() self._get_current_schema_version() self._apply_pending_migrations() @classmethod def connect(cls, path: Optional[str] = None) -> "DuckDBStore": """Create a new store instance with optional custom path""" if not path: path = os.path.join(os.getcwd(), DEFAULT_DB_NAME) return cls(path=path) def _init_table(self, record: "Record") -> None: self._conn.execute(record.duckdb_schema) self._tables.append(record.table_name) def _get_tables(self) -> List[str]: """Get all tables in the database""" return [table[0] for table in self._conn.execute("SHOW TABLES").fetchall()] def add(self, record: "Record"): """Add a new record to the database""" if record.table_name not in self._tables: self._init_table(record) record_dict = asdict(record) for json_field in record.json_fields: if record_dict[json_field]: record_dict[json_field] = json.dumps(record_dict[json_field]) placeholders = ", ".join( ["$" + str(i + 1) for i in range(len(record.table_columns))] ) # Sort record_dict based on table_columns order if hasattr(record, "table_columns"): sorted_dict = {k: record_dict[k] for k in record.table_columns} record_dict = sorted_dict self._conn.execute( f"INSERT INTO {record.table_name} VALUES ({placeholders})", [ record_dict[k] if k in record_dict else None for k in record.table_columns ], ) async def add_async(self, record: "Record"): """Add a new record to the database asynchronously""" await asyncio.to_thread(self.add, record) def close(self) -> None: """Close the database connection""" if self._conn: self._conn.close() self._conn = None def __enter__(self): return self def _migrate_schema(self, migration_script: str): """Apply a schema migration""" self._conn.execute(migration_script) def _get_current_schema_version(self) -> int: """Get the current schema version, creating the table if it doesn't exist""" table_exists = self._conn.execute( "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'schema_version'" ).fetchone()[0] # create the schema_version table if it doesn't exist if not table_exists: self._conn.execute( """ CREATE TABLE schema_version ( version INTEGER PRIMARY KEY, migration_name VARCHAR, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """ ) self._conn.execute( "INSERT INTO schema_version (version, migration_name) VALUES (0, 'initial')" ) # retrieve the current schema version result = self._conn.execute( "SELECT version FROM schema_version ORDER BY version DESC LIMIT 1" ).fetchone() return result[0] if result else 0 def __exit__(self, exc_type, exc_val, exc_tb): self.close() def _get_migrations_path(self) -> Path: """Get the path to migrations directory""" return Path(__file__).parent / "migrations" def _get_available_migrations(self) -> List[tuple[int, str]]: """Get all available migration files sorted by version""" migrations_path = self._get_migrations_path() migration_files = glob.glob(str(migrations_path / "*.sql")) # extract version and path using regex migrations = [] for file_path in migration_files: # Match migration files in format: any_prefix_NUMBER_any_suffix.sql # e.g., "001_create_users.sql" or "v1_init.sql" - extracts "1" as version if match := re.match(r".*?(\d+)_.+\.sql$", file_path): version = int(match.group(1)) migrations.append((version, file_path)) return sorted(migrations) def _apply_pending_migrations(self): """Apply any pending migrations""" current_version = self._get_current_schema_version() available_migrations = self._get_available_migrations() for version, migration_path in available_migrations: if version > current_version: with open(migration_path, "r") as f: migration_script = f.read() migration_name = Path( migration_path ).stem # Gets filename without extension self._conn.execute("BEGIN TRANSACTION") try: self._migrate_schema(migration_script) self._conn.execute( "INSERT INTO schema_version (version, migration_name) VALUES (?, ?)", [version, migration_name], ) self._conn.execute("COMMIT") except Exception as e: self._conn.execute("ROLLBACK") raise Exception(f"Migration {version} failed: {str(e)}") from e def _check_table_exists(self, table_name: str) -> bool: """Check if a table exists in the database""" result = self._conn.execute( "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?", [table_name], ).fetchone()[0] return bool(result) def _create_version_table(self): """Create the schema version table""" self._conn.execute( """ CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY KEY, migration_name VARCHAR, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """ ) def _execute(self, query: str, params: Optional[List] = None): """Execute a SQL query""" return self._conn.execute(query, params if params else []) ================================================ FILE: src/observers/stores/migrations/001_create_schema_version.sql ================================================ CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY KEY, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, migration_name VARCHAR, checksum VARCHAR ); CREATE TABLE IF NOT EXISTS openai_records ( id VARCHAR PRIMARY KEY, model VARCHAR, timestamp TIMESTAMP, messages JSON, assistant_message TEXT, completion_tokens INTEGER, prompt_tokens INTEGER, total_tokens INTEGER, finish_reason VARCHAR, tool_calls JSON, function_call JSON, tags VARCHAR[], properties JSON, error VARCHAR, raw_response JSON, arguments JSON ); -- Initialize with version 0 if table is empty INSERT INTO schema_version (version, migration_name) SELECT 0, 'initial' WHERE NOT EXISTS (SELECT 1 FROM schema_version);
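Because migration 001 above creates the `openai_records` table up front, a freshly connected store can be queried directly with SQL. A minimal sketch, assuming records were previously logged to the default `store.db` in the working directory:

```python
from observers.stores.duckdb import DuckDBStore

store = DuckDBStore.connect()  # opens ./store.db and applies pending migrations
rows = store._execute(
    "SELECT model, total_tokens, assistant_message FROM openai_records LIMIT 5"
).fetchall()
for model, tokens, message in rows:
    print(model, tokens, (message or "")[:80])
store.close()
```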
================================================ FILE: src/observers/stores/migrations/002_add_arguments_field.sql ================================================ ALTER TABLE IF EXISTS openai_records ADD COLUMN IF NOT EXISTS arguments JSON; ALTER TABLE IF EXISTS openai_records DROP COLUMN IF EXISTS synced_at; ================================================ FILE: src/observers/stores/migrations/__init__.py ================================================ ================================================ FILE: src/observers/stores/opentelemetry.py ================================================ # stdlib features import asyncio from dataclasses import dataclass from importlib.metadata import PackageNotFoundError, version from typing import Optional # Actual dependencies from opentelemetry import trace from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import Span, Tracer, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter # Observers internal interfaces from observers.base import Record from observers.stores.base import Store def flatten_dict(d, prefix=""): """Flatten a python dictionary, turning nested keys into dotted keys""" flat = {} for k, v in d.items(): if v: key = f"{prefix}.{k}" if prefix else k if type(v) is dict: flat.update(flatten_dict(v, key)) else: flat[key] = v return flat def get_version(): try: return version("observers") except PackageNotFoundError: return "unknown" @dataclass class OpenTelemetryStore(Store): """ OpenTelemetry Store """ # These are here largely to ease future refactors/conform to # the style of the other stores. They have defaults set in the constructor, # but, set here as well.
tracer: Optional[Tracer] = None root_span: Optional[Span] = None exporter: Optional[SpanExporter] = None namespace: str = "observers.dev/observers" def __post_init__(self): if not self.tracer: provider = TracerProvider( resource=Resource.create( { "instrument.name": self.namespace, "instrument.version": get_version(), } ), ) if not self.exporter: provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter())) else: provider.add_span_processor(BatchSpanProcessor(self.exporter)) trace.set_tracer_provider(provider) self.tracer = trace.get_tracer(self.namespace) if not self.root_span: # if we initialize a span here, then all subsequent 'add's can be # added to a continuous trace with self.tracer.start_as_current_span(f"{self.namespace}.init") as span: span.set_attribute("connected", True) self.root_span = span def add(self, record: Record): """Add a new record to the store""" with trace.use_span(self.root_span): with self.tracer.start_as_current_span(f"{self.namespace}.add") as span: # Split out to be easily edited if the record api changes event_fields = [ "assistant_message", "completion_tokens", "total_tokens", "prompt_tokens", "finish_reason", "tool_calls", "function_call", "tags", "properties", "error", "model", "timestamp", "id", ] for field in event_fields: data = getattr(record, field) if data: if type(data) is dict: intermediate = flatten_dict(data, field) for k, v in intermediate.items(): span.set_attribute(k, v) else: span.set_attribute(field, data) # Special case for `messages` as it is a list of dicts messages = [str(message) for message in record.messages] span.set_attribute("messages", messages) @classmethod def connect(cls, tracer=None, root_span=None, namespace=None, exporter=None): """Create an ObservabilityStore, optionally starting from a prior tracer or trace, assigning a custom namespace, or setting an alternate exporter""" kwargs = {"tracer": tracer, "root_span": root_span, "exporter": exporter} if namespace is not None: kwargs["namespace"] = namespace return cls(**kwargs) def _init_table(self, record: "Record"): """Initialize the dataset (no op)""" # We don't usually do this in otel, a dataset is (typically) # initialized by writing to it, but, it's part of the Store interface. async def add_async(self, record: Record): """Add a new record to the store asynchronously""" await asyncio.to_thread(self.add, record)
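By default `OpenTelemetryStore` wires up a gRPC `OTLPSpanExporter`, so spans go to whatever collector `OTEL_EXPORTER_OTLP_ENDPOINT` points at. A minimal sketch that swaps in the SDK's `ConsoleSpanExporter` for local experimentation, assuming no collector is running:

```python
from opentelemetry.sdk.trace.export import ConsoleSpanExporter

from observers.stores.opentelemetry import OpenTelemetryStore

# prints spans to stdout instead of shipping them over OTLP
store = OpenTelemetryStore.connect(exporter=ConsoleSpanExporter())
# each store.add(record) now emits an "observers.dev/observers.add" span
# whose attributes are the flattened record fields
```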
================================================ FILE: src/observers/stores/sql_base.py ================================================ from abc import abstractmethod from dataclasses import dataclass from typing import List, Optional from observers.stores.base import Store @dataclass class SQLStore(Store): """Base class for SQL-based stores with migration capabilities""" @abstractmethod def _check_table_exists(self, table_name: str) -> bool: """Check if a table exists in the database""" pass @abstractmethod def _create_version_table(self): """Create the schema version table""" pass @abstractmethod def _execute(self, query: str, params: Optional[List] = None): """Execute a SQL query""" pass @abstractmethod def _migrate_schema(self, migration_script: str): """Execute a migration script""" pass @abstractmethod def close(self) -> None: """Close the database connection""" pass @abstractmethod def _get_current_schema_version(self) -> int: """Get the current schema version""" pass @abstractmethod def _apply_pending_migrations(self): """Apply any pending migrations""" pass ================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/conftest.py ================================================ from unittest.mock import AsyncMock, MagicMock, create_autospec import pytest from observers.stores.datasets import DatasetsStore @pytest.fixture(autouse=True) def mock_store(monkeypatch): """Mock the datasets store for all tests""" async def mock_add_async(*args, **kwargs): return None async def mock_close_async(*args, **kwargs): return None def mock_add(*args, **kwargs): return None def mock_close(*args, **kwargs): return None store_mock = create_autospec(DatasetsStore, spec_set=False, instance=True) store_mock.add_async = AsyncMock(side_effect=mock_add_async) store_mock.close_async = AsyncMock(side_effect=mock_close_async) store_mock.add = MagicMock(side_effect=mock_add) store_mock.close = MagicMock(side_effect=mock_close) def mock_connect(*args, **kwargs): return store_mock # Patch both the class and the connect method monkeypatch.setattr("observers.stores.datasets.DatasetsStore.connect", mock_connect) monkeypatch.setattr( "observers.stores.datasets.DatasetsStore", lambda *args, **kwargs: store_mock ) return store_mock ================================================ FILE: tests/integration/models/test_async_examples.py ================================================ import asyncio import os import uuid from unittest.mock import MagicMock, create_autospec import pytest from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice, CompletionUsage def get_async_example_files() -> list[str]: """Get list of asynchronous example files to test Returns: list[str]: List of paths to asynchronous example files """ examples_dir = "examples/models" if not os.path.exists(examples_dir): return [] async_files = [] for f in os.listdir(examples_dir): if not f.endswith(".py"): continue filepath = os.path.join(examples_dir, f) with open(filepath) as file: content = file.read() if "async" in content and "stream" not in content: async_files.append(filepath) return async_files @pytest.fixture def mock_clients(monkeypatch): """Fixture providing mocked API clients""" # Add async mock
client async def async_openai_fake_return(*args, **kwargs): return ChatCompletion( id=str(uuid.uuid4()), choices=[ Choice( message=ChatCompletionMessage( content="", role="assistant", tool_calls=None, audio=None ), finish_reason="stop", index=0, logprobs=None, ) ], model="gpt-4", usage=CompletionUsage( prompt_tokens=10, completion_tokens=10, total_tokens=20 ), created=1727238800, object="chat.completion", system_fingerprint=None, ) async_base_mock = create_autospec(AsyncOpenAI, spec_set=False, instance=True) async_base_mock.chat = MagicMock() async_base_mock.chat.completions = MagicMock() async_base_mock.chat.completions.create = MagicMock( side_effect=async_openai_fake_return ) monkeypatch.setattr("openai.AsyncOpenAI", lambda *args, **kwargs: async_base_mock) @pytest.mark.parametrize("example_path", get_async_example_files()) @pytest.mark.asyncio async def test_async_example_files(example_path, mock_clients): """Test that async example files execute without errors""" print(f"Executing async example: {os.path.basename(example_path)}") with open(example_path) as f: content = f.read() exec_globals = {} exec(content, exec_globals) async_functions = [ f for f in exec_globals.values() if callable(f) and asyncio.iscoroutinefunction(f) ] if async_functions: await async_functions[0]() else: pytest.fail(f"No async functions found in {os.path.basename(example_path)}") ================================================ FILE: tests/integration/models/test_examples.py ================================================ import os import uuid from unittest.mock import MagicMock, patch import litellm import pytest from huggingface_hub import ChatCompletionOutput from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice, CompletionUsage def get_sync_example_files() -> list[str]: """ Get list of synchronous example files to test """ examples_dir = "examples/models" if not os.path.exists(examples_dir): return [] sync_files = [] for f in os.listdir(examples_dir): if not f.endswith(".py"): continue filepath = os.path.join(examples_dir, f) with open(filepath) as file: content = file.read() if ( "async def" not in content and "await" not in content and "stream=True" not in content ): sync_files.append(filepath) return sync_files @pytest.fixture(scope="function") def mock_clients(): """Fixture providing mocked API clients""" def openai_fake_return(*args, **kwargs): return ChatCompletion( id=str(uuid.uuid4()), choices=[ Choice( message=ChatCompletionMessage( content="", role="assistant", tool_calls=None, audio=None ), finish_reason="stop", index=0, logprobs=None, ) ], model="gpt-4", usage=CompletionUsage( prompt_tokens=10, completion_tokens=10, total_tokens=20 ), created=1727238800, object="chat.completion", system_fingerprint=None, ) def hf_fake_return(*args, **kwargs): return ChatCompletionOutput( id=str(uuid.uuid4()), model="Qwen/Qwen2.5-Coder-32B-Instruct", choices=[{"message": {"content": "Hello, world!"}}], created=1727238800, usage={"prompt_tokens": 10, "completion_tokens": 10, "total_tokens": 20}, system_fingerprint=None, ) # Create base mock for other clients base_mock = MagicMock() base_mock.chat.completions.create = MagicMock(side_effect=openai_fake_return) hf_mock = MagicMock() hf_mock.chat.completions.create = MagicMock(side_effect=hf_fake_return) mocks = { # Sync clients "openai.OpenAI": patch("openai.OpenAI", return_value=base_mock), "litellm.completion": patch("litellm.completion", litellm.mock_completion), "aisuite.Client": 
patch("aisuite.Client", return_value=base_mock), "huggingface_hub.InferenceClient": patch( "huggingface_hub.InferenceClient", return_value=hf_mock ), } # Start all patches for mock in mocks.values(): mock.start() yield # Stop all patches for mock in mocks.values(): mock.stop() @pytest.mark.parametrize("example_path", get_sync_example_files()) def test_sync_example_files(example_path, mock_clients): """Test that synchronous example files execute without errors""" if "async def" in open(example_path).read() or "await" in open(example_path).read(): pytest.skip("Skipping async example in sync test") print(f"Executing sync example: {os.path.basename(example_path)}") try: with open(example_path) as f: exec(f.read()) except Exception as e: pytest.fail(f"Failed to execute {os.path.basename(example_path)}: {str(e)}") ================================================ FILE: tests/integration/models/test_stream_examples.py ================================================ import asyncio import os import uuid from unittest.mock import MagicMock, create_autospec import pytest from huggingface_hub import ( AsyncInferenceClient, ChatCompletionStreamOutput, ChatCompletionStreamOutputChoice, ChatCompletionStreamOutputDelta, ) from openai import AsyncOpenAI from openai.types.chat.chat_completion_chunk import ( ChatCompletionChunk, Choice, ChoiceDelta, ) def get_async_example_files() -> list[str]: """Get list of asynchronous example files to test Returns: list[str]: List of paths to asynchronous example files """ examples_dir = "examples/models" if not os.path.exists(examples_dir): return [] async_files = [] for f in os.listdir(examples_dir): if not f.endswith(".py"): continue filepath = os.path.join(examples_dir, f) with open(filepath) as file: content = file.read() if "stream=True" in content: async_files.append(filepath) return async_files @pytest.fixture def mock_clients(monkeypatch): """Fixture providing mocked API clients""" # Add async mock client async def async_openai_fake_return(*args, **kwargs): async def async_iter(): yield ChatCompletionChunk( id=str(uuid.uuid4()), choices=[ Choice( index=0, delta=ChoiceDelta( content="chunk0", ), ) ], model="gpt-4o", usage={ "prompt_tokens": 10, "completion_tokens": 10, "total_tokens": 20, }, created=1727238800, system_fingerprint=None, object="chat.completion.chunk", ) return async_iter() async_base_mock = create_autospec(AsyncOpenAI, spec_set=False, instance=True) async_base_mock.chat = MagicMock() async_base_mock.chat.completions = MagicMock() async_base_mock.chat.completions.create = MagicMock( side_effect=async_openai_fake_return ) monkeypatch.setattr("openai.AsyncOpenAI", lambda *args, **kwargs: async_base_mock) # Add HF mock client async def hf_fake_return(*args, **kwargs): async def async_iter(): yield ChatCompletionStreamOutput( id=str(uuid.uuid4()), choices=[ ChatCompletionStreamOutputChoice( index=0, delta=ChatCompletionStreamOutputDelta( content="chunk0", role="assistant", ), ) ], model="gpt2", usage={ "prompt_tokens": 10, "completion_tokens": 10, "total_tokens": 20, }, created=1727238800, system_fingerprint=None, ) return async_iter() hf_base_mock = create_autospec(AsyncInferenceClient, spec_set=False, instance=True) hf_base_mock.chat = MagicMock() hf_base_mock.chat.completions = MagicMock() hf_base_mock.chat.completions.create = MagicMock(side_effect=hf_fake_return) monkeypatch.setattr( "huggingface_hub.AsyncInferenceClient", lambda *args, **kwargs: hf_base_mock ) @pytest.mark.parametrize("example_path", get_async_example_files()) 
@pytest.mark.asyncio async def test_async_example_files(example_path, mock_clients): """Test that async example files execute without errors""" print(f"Executing async example: {os.path.basename(example_path)}") with open(example_path) as f: content = f.read() exec_globals = {} exec(content, exec_globals) async_functions = [ f for f in exec_globals.values() if callable(f) and asyncio.iscoroutinefunction(f) ] if async_functions: await async_functions[0]() else: pytest.fail(f"No async functions found in {os.path.basename(example_path)}") ================================================ FILE: tests/unit/stores/test_datasets.py ================================================ import os import pytest from unittest.mock import patch from observers.stores.datasets import DatasetsStore @pytest.fixture def mock_whoami(): with patch("observers.stores.datasets.whoami") as mock: mock.return_value = {} yield mock @pytest.fixture def mock_login(): with patch("observers.stores.datasets.login") as mock: yield mock @pytest.fixture def datasets_store(mock_whoami, mock_login): store = DatasetsStore() yield store store._cleanup() def test_temp_dir_creation(datasets_store): """Test that temporary directory is created during initialization""" assert datasets_store._temp_dir is not None assert os.path.exists(datasets_store._temp_dir) def test_temp_dir_cleanup(datasets_store): """Test that temporary directory is cleaned up properly""" temp_dir = datasets_store._temp_dir assert os.path.exists(temp_dir) datasets_store._cleanup() assert not os.path.exists(temp_dir) def test_folder_path_defaults_to_temp_dir(datasets_store): """Test that folder_path defaults to temp_dir when not provided""" assert datasets_store.folder_path == datasets_store._temp_dir def test_custom_folder_path(mock_whoami, mock_login, tmp_path): """Test that custom folder_path is respected and not deleted during cleanup""" custom_path = str(tmp_path / "custom_datasets") os.makedirs(custom_path, exist_ok=True) store = DatasetsStore(folder_path=custom_path) assert store.folder_path == custom_path assert store._temp_dir is None store._cleanup() assert os.path.exists( custom_path ), "Custom folder should not be deleted during cleanup"
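The unit tests above only cover `DatasetsStore`'s temp-dir lifecycle; a companion test for `DuckDBStore` could follow the same shape. A minimal sketch, assuming pytest's built-in `tmp_path` fixture; this test is illustrative and not part of the repository:

```python
# illustrative sketch of an additional unit test, not part of the repo
import os

from observers.stores.duckdb import DuckDBStore


def test_duckdb_store_applies_migrations(tmp_path):
    """Connecting should create the db file and bootstrap the schema tables"""
    db_path = str(tmp_path / "test_store.db")
    store = DuckDBStore.connect(path=db_path)
    tables = store._get_tables()
    assert "schema_version" in tables
    assert "openai_records" in tables  # created by migration 001
    store.close()
    assert os.path.exists(db_path)
```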